From 0860f9d0848eed4bc6afdcce93b7cb816c346325 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 20 Feb 2025 11:10:44 -0800 Subject: [PATCH 01/65] introduce table constraint --- .../resources/error/error-conditions.json | 22 +++++ .../sql/catalyst/parser/SqlBaseParser.g4 | 9 ++ .../sql/connector/catalog/Constraint.java | 60 +++++++++++++ .../sql/connector/catalog/TableChange.java | 89 +++++++++++++++++++ .../sql/catalyst/analysis/CheckAnalysis.scala | 22 +++++ .../sql/catalyst/parser/AstBuilder.scala | 20 +++++ .../plans/logical/v2AlterTableCommands.scala | 36 +++++++- 7 files changed, 256 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index f1012edd2de2..727f06be8914 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -4138,6 +4138,28 @@ }, "sqlState" : "0A000" }, + "NOT_SUPPORTED_TABLE_CONSTRAINT": { + "message" : [ + "Table constraint is not supported: " + ], + "subClass" : { + "INVALID_V2_PREDICATE" : { + "message" : [ + "cannot convert to data source V2 predicate." + ] + }, + "NONDETERMINISTIC" : { + "message" : [ + "nondeterministic expression." + ] + }, + "UNRESOLVED" : { + "message" : [ + "cannot resolve expression." + ] + } + } + }, "NOT_UNRESOLVED_ENCODER" : { "message" : [ "Unresolved encoder expected, but was found." diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 59a0b1ce7a3c..86d9d7b2d0f2 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -261,6 +261,11 @@ statement | ALTER TABLE identifierReference (clusterBySpec | CLUSTER BY NONE) #alterClusterBy | ALTER TABLE identifierReference collationSpec #alterTableCollation + | ALTER TABLE identifierReference ADD CONSTRAINT name=identifier + constraint #addTableConstraint + | ALTER TABLE identifierReference + DROP CONSTRAINT (IF EXISTS)? name=identifier + (RESTRICT | CASCADE)? #dropTableConstraint | DROP TABLE (IF EXISTS)? identifierReference PURGE? #dropTable | DROP VIEW (IF EXISTS)? identifierReference #dropView | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? @@ -1516,6 +1521,10 @@ number | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral ; +constraint + : CHECK '(' booleanExpression ')' #checkConstraint + ; + alterColumnSpecList : alterColumnSpec (COMMA alterColumnSpec)* ; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java new file mode 100644 index 000000000000..263c83fbf13e --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector.catalog; + +import org.apache.spark.sql.connector.expressions.filter.Predicate; + +public interface Constraint { + String name(); // either assigned by the user or auto generated + String description(); // used in toString() + String toDDL(); // used in EXPLAIN/DESCRIBE/SHOW CREATE TABLE + boolean rely(); // indicates whether the constraint is believed to be true + boolean enforced(); // indicates whether the constraint must be enforced + + static Constraint check(String name, Predicate predicate) { + return new Check(name, predicate); + } + + final class Check implements Constraint { + private final String name; + private final Predicate predicate; + private Check(String name, Predicate predicate) { + this.name = name; + this.predicate = predicate; + } + @Override public String name() { + return name; + } + + @Override public String description() { + return "check constraint"; + } + + @Override + public String toDDL() { + return "Check (" + predicate.toString() + ")"; + } + + @Override public boolean rely() { + return true; + } + + @Override public boolean enforced() { + return true; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java index d7a51c519e09..beaa4e4afaa8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java @@ -260,6 +260,24 @@ static TableChange clusterBy(NamedReference[] clusteringColumns) { return new ClusterBy(clusteringColumns); } + /** + * Create a TableChange for adding a new Table Constraint + */ + static TableChange addCheckConstraint(Constraint constraint, Boolean validate) { + return new AddCheckConstraint(constraint, validate); + } + + /** + * Create a TableChange for dropping a Table Constraint + */ + static TableChange dropConstraint(String name, Boolean ifExists, Boolean cascade) { + DropConstraintMode mode = DropConstraintMode.RESTRICT; + if (cascade) { + mode = DropConstraintMode.CASCADE; + } + return new DropConstraint(name, ifExists, mode); + } + /** * A TableChange to set a table property. *
@@ -787,4 +805,75 @@ public int hashCode() { return Arrays.hashCode(clusteringColumns); } } + + /** A TableChange to alter table and add a constraint. */ + final class AddCheckConstraint implements TableChange { + private final Constraint constraint; + private final boolean validate; + + private AddCheckConstraint(Constraint constraint, boolean validate) { + this.constraint = constraint; + this.validate = validate; + } + + public Constraint getConstraint() { + return constraint; + } + + public boolean isValidate() { + return validate; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AddCheckConstraint that = (AddCheckConstraint) o; + return constraint.equals(that.constraint) && validate == that.validate; + } + + @Override + public int hashCode() { + return Objects.hash(constraint, validate); + } + } + + enum DropConstraintMode { RESTRICT, CASCADE } + + /** A TableChange to alter table and drop a constraint. */ + final class DropConstraint implements TableChange { + private final String name; + private final boolean ifExists; + private final DropConstraintMode mode; + + private DropConstraint(String name, boolean ifExists, DropConstraintMode mode) { + this.name = name; + this.ifExists = ifExists; + this.mode = mode; + } + + public String getName() { + return name; + } + + public boolean isIfExists() { + return ifExists; + } + + public DropConstraintMode getMode() { + return mode; + } + + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DropConstraint that = (DropConstraint) o; + return that.name.equals(name) && that.ifExists == ifExists && mode == that.mode; + } + + @Override + public int hashCode() { + return Objects.hash(name, ifExists, mode); + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 1b45fcde9126..22b6b14176e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1190,6 +1190,28 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString } case _ => } + + case addConstraint @ AlterTableAddConstraint(table: ResolvedTable, name, constraintExpr) => + if (!constraintExpr.resolved) { + alter.failAnalysis( + errorClass = "NOT_SUPPORTED_TABLE_CONSTRAINT.UNRESOLVED", + messageParameters = Map("expression" -> constraintExpr.toString) + ) + } + + if (!constraintExpr.deterministic) { + alter.failAnalysis( + errorClass = "NOT_SUPPORTED_TABLE_CONSTRAINT.NONDETERMINISTIC", + messageParameters = Map("expression" -> constraintExpr.toString) + ) + } + + if (addConstraint.predicate.isEmpty) { + alter.failAnalysis( + errorClass = "NOT_SUPPORTED_TABLE_CONSTRAINT.INVALID_V2_PREDICATE", + messageParameters = Map("expression" -> constraintExpr.toString) + ) + } case _ => } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cd7af021d8ff..40c37aeb728b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5237,6 +5237,26 @@ class AstBuilder extends 
DataTypeAstBuilder AlterTableCollation(table, visitCollationSpec(ctx.collationSpec())) } + override def visitAddTableConstraint(ctx: AddTableConstraintContext): LogicalPlan = + withOrigin(ctx) { + val table = createUnresolvedTable( + ctx.identifierReference, "ALTER TABLE ... ADD CONSTRAINT") + val constraint = + expression(ctx.constraint().asInstanceOf[CheckConstraintContext].booleanExpression()) + AlterTableAddConstraint(table, ctx.name.getText, constraint) + } + + override def visitDropTableConstraint(ctx: DropTableConstraintContext): LogicalPlan = + withOrigin(ctx) { + val table = createUnresolvedTable( + ctx.identifierReference, "ALTER TABLE ... DROP CONSTRAINT") + AlterTableDropConstraint( + table, + ctx.name.getText, + ifExists = ctx.EXISTS() != null, + cascade = ctx.CASCADE() != null) + } + /** * Parse [[SetViewProperties]] or [[SetTableProperties]] commands. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala index a0def801ee6f..c2368bda3697 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, Resolve import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ClusterBySpec import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable} -import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, TypeUtils} -import org.apache.spark.sql.connector.catalog.{TableCatalog, TableChange} +import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, TypeUtils, V2ExpressionBuilder} +import org.apache.spark.sql.connector.catalog.{Constraint, TableCatalog, TableChange} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.DataType import org.apache.spark.util.ArrayImplicits._ @@ -288,3 +288,35 @@ case class AlterTableCollation( protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild) } + +/** + * The logical plan of the ALTER TABLE ... ADD CONSTRAINT command. + */ +case class AlterTableAddConstraint( + table: LogicalPlan, + name: String, + constraintExpr: Expression) extends AlterTableCommand { + + lazy val predicate = new V2ExpressionBuilder(constraintExpr, true).buildPredicate() + + override def changes: Seq[TableChange] = { + val constraint = Constraint.check(name, predicate.get) + Seq(TableChange.addCheckConstraint(constraint, constraint.enforced())) + } + + protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild) +} + +/** + * The logical plan of the ALTER TABLE ... DROP CONSTRAINT command. 
+ */ +case class AlterTableDropConstraint( + table: LogicalPlan, + name: String, + ifExists: Boolean, + cascade: Boolean) extends AlterTableCommand { + override def changes: Seq[TableChange] = + Seq(TableChange.dropConstraint(name, ifExists, cascade)) + + protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild) +} From 0b222ba73afa215b3235f5cc7fd7e29ab78f2e35 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 20 Feb 2025 13:37:17 -0800 Subject: [PATCH 02/65] add AlterTableAddConstraintParseSuite and AlterTableDropConstraintParseSuite --- .../AlterTableAddConstraintParseSuite.scala | 64 +++++++++++++ .../AlterTableDropConstraintParseSuite.scala | 94 +++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropConstraintParseSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala new file mode 100644 index 000000000000..c266d173d651 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedTable} +import org.apache.spark.sql.catalyst.expressions.{GreaterThan, Literal} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.AlterTableAddConstraint +import org.apache.spark.sql.test.SharedSparkSession + +class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSession { + + test("Add check constraint") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AlterTableAddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... 
ADD CONSTRAINT"), + "c1", + GreaterThan(UnresolvedAttribute("d"), Literal(0))) + comparePlans(parsed, expected) + } + + test("Add invalid check constraint name") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT c1-c3 CHECK (d > 0) + |""".stripMargin + val msg = intercept[ParseException] { + parsePlan(sql) + }.getMessage + assert(msg.contains("Syntax error at or near '-'.")) + } + + test("Add invalid check constraint expression") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d >) + |""".stripMargin + val msg = intercept[ParseException] { + parsePlan(sql) + }.getMessage + assert(msg.contains("Syntax error at or near ')'")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropConstraintParseSuite.scala new file mode 100644 index 000000000000..98ff55674f34 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropConstraintParseSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedTable} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.AlterTableDropConstraint +import org.apache.spark.sql.test.SharedSparkSession + +class AlterTableDropConstraintParseSuite extends AnalysisTest with SharedSparkSession { + + test("Drop constraint") { + val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1" + val parsed = parsePlan(sql) + val expected = AlterTableDropConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... DROP CONSTRAINT"), + "c1", + ifExists = false, + cascade = false) + comparePlans(parsed, expected) + } + + test("Drop constraint if exists") { + val sql = "ALTER TABLE a.b.c DROP CONSTRAINT IF EXISTS c1" + val parsed = parsePlan(sql) + val expected = AlterTableDropConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... DROP CONSTRAINT"), + "c1", + ifExists = true, + cascade = false) + comparePlans(parsed, expected) + } + + test("Drop constraint cascade") { + val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1 CASCADE" + val parsed = parsePlan(sql) + val expected = AlterTableDropConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... 
DROP CONSTRAINT"), + "c1", + ifExists = false, + cascade = true) + comparePlans(parsed, expected) + } + + test("Drop constraint restrict") { + val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1 RESTRICT" + val parsed = parsePlan(sql) + val expected = AlterTableDropConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... DROP CONSTRAINT"), + "c1", + ifExists = false, + cascade = false) + comparePlans(parsed, expected) + } + + test("Drop constraint with invalid name") { + val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1-c3 ENFORCE" + val msg = intercept[ParseException] { + parsePlan(sql) + }.getMessage + assert(msg.contains("Syntax error at or near '-'")) + } + + test("Drop constraint with invalid mode") { + val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1 ENFORCE" + val msg = intercept[ParseException] { + parsePlan(sql) + }.getMessage + assert(msg.contains("Syntax error at or near 'ENFORCE': extra input 'ENFORCE'.")) + } +} From 40d27aec05381427c296bc6599ea05020fa46d20 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 21 Feb 2025 13:56:15 -0800 Subject: [PATCH 03/65] update error conditions and add tests --- .../resources/error/error-conditions.json | 40 ++++---- .../sql/catalyst/analysis/CheckAnalysis.scala | 20 ++-- .../command/v2/CheckConstraintSuite.scala | 95 +++++++++++++++++++ 3 files changed, 123 insertions(+), 32 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 727f06be8914..d20b110e665a 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2303,6 +2303,24 @@ ], "sqlState" : "22P03" }, + "INVALID_CHECK_CONSTRAINT": { + "message" : [ + "The check constraint expression is invalid." + ], + "subClass" : { + "INVALID_V2_PREDICATE" : { + "message" : [ + "It cannot be converted to a data source V2 predicate." + ] + }, + "NONDETERMINISTIC" : { + "message" : [ + "It contains nondeterministic expression." + ] + } + }, + "sqlState": "42621" + }, "INVALID_COLUMN_NAME_AS_PATH" : { "message" : [ "The datasource cannot save the column because its name contains some characters that are not allowed in file paths. Please, use an alias to rename it." @@ -4138,28 +4156,6 @@ }, "sqlState" : "0A000" }, - "NOT_SUPPORTED_TABLE_CONSTRAINT": { - "message" : [ - "Table constraint is not supported: " - ], - "subClass" : { - "INVALID_V2_PREDICATE" : { - "message" : [ - "cannot convert to data source V2 predicate." - ] - }, - "NONDETERMINISTIC" : { - "message" : [ - "nondeterministic expression." - ] - }, - "UNRESOLVED" : { - "message" : [ - "cannot resolve expression." - ] - } - } - }, "NOT_UNRESOLVED_ENCODER" : { "message" : [ "Unresolved encoder expected, but was found." 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 22b6b14176e6..7cac1316bb26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1191,25 +1191,25 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString case _ => } - case addConstraint @ AlterTableAddConstraint(table: ResolvedTable, name, constraintExpr) => + case addConstraint @ AlterTableAddConstraint(table: ResolvedTable, _, constraintExpr) => if (!constraintExpr.resolved) { - alter.failAnalysis( - errorClass = "NOT_SUPPORTED_TABLE_CONSTRAINT.UNRESOLVED", - messageParameters = Map("expression" -> constraintExpr.toString) + constraintExpr.failAnalysis( + errorClass = "INVALID_CHECK_CONSTRAINT.UNRESOLVED", + messageParameters = Map.empty ) } if (!constraintExpr.deterministic) { - alter.failAnalysis( - errorClass = "NOT_SUPPORTED_TABLE_CONSTRAINT.NONDETERMINISTIC", - messageParameters = Map("expression" -> constraintExpr.toString) + constraintExpr.failAnalysis( + errorClass = "INVALID_CHECK_CONSTRAINT.NONDETERMINISTIC", + messageParameters = Map.empty ) } if (addConstraint.predicate.isEmpty) { - alter.failAnalysis( - errorClass = "NOT_SUPPORTED_TABLE_CONSTRAINT.INVALID_V2_PREDICATE", - messageParameters = Map("expression" -> constraintExpr.toString) + constraintExpr.failAnalysis( + errorClass = "INVALID_CHECK_CONSTRAINT.INVALID_V2_PREDICATE", + messageParameters = Map.empty ) } case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala new file mode 100644 index 000000000000..ef64a0aca817 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.{AnalysisException, QueryTest} + +class CheckConstraintSuite extends QueryTest with CommandSuiteBase { + test("Nondeterministic expression") { + withTable("t") { + sql("create table t(i double) using parquet") + val query = + """ + |ALTER TABLE t ADD CONSTRAINT c1 CHECK (i > rand(0)) + |""".stripMargin + val error = intercept[AnalysisException] { + sql(query) + } + checkError( + exception = error, + condition = "INVALID_CHECK_CONSTRAINT.NONDETERMINISTIC", + sqlState = "42621", + parameters = Map.empty, + context = ExpectedContext( + fragment = "i > rand(0)", + start = 40, + stop = 50 + ) + ) + } + } + + test("Expression referring a column of another table") { + withTable("t", "t2") { + sql("create table t(i double) using parquet") + sql("create table t2(j string) using parquet") + val query = + """ + |ALTER TABLE t ADD CONSTRAINT c1 CHECK (len(t2.j) > 0) + |""".stripMargin + val error = intercept[AnalysisException] { + sql(query) + } + checkError( + exception = error, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> "`t2`.`j`", "proposal" -> "`t`.`i`"), + context = ExpectedContext( + fragment = "t2.j", + start = 44, + stop = 47 + ) + ) + } + } + + test("Can't convert expression to V2 predicate") { + withTable("t") { + sql("create table t(i string) using parquet") + val query = + """ + |ALTER TABLE t ADD CONSTRAINT c1 CHECK (from_json(i, 'a INT').a > 1) + |""".stripMargin + val error = intercept[AnalysisException] { + sql(query) + } + checkError( + exception = error, + condition = "INVALID_CHECK_CONSTRAINT.INVALID_V2_PREDICATE", + sqlState = "42621", + parameters = Map.empty, + context = ExpectedContext( + fragment = "from_json(i, 'a INT').a > 1", + start = 40, + stop = 66 + ) + ) + } + } +} From f244c752366a35ea1dd2bf13b09e22816dae606e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 21 Feb 2025 14:00:35 -0800 Subject: [PATCH 04/65] rename logical plans --- .../spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../apache/spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- .../catalyst/plans/logical/v2AlterTableCommands.scala | 4 ++-- .../command/AlterTableAddConstraintParseSuite.scala | 4 ++-- .../command/AlterTableDropConstraintParseSuite.scala | 10 +++++----- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7cac1316bb26..c94f5476d30b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1191,7 +1191,7 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString case _ => } - case addConstraint @ AlterTableAddConstraint(table: ResolvedTable, _, constraintExpr) => + case addConstraint @ AddConstraint(table: ResolvedTable, _, constraintExpr) => if (!constraintExpr.resolved) { constraintExpr.failAnalysis( errorClass = "INVALID_CHECK_CONSTRAINT.UNRESOLVED", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 40c37aeb728b..337729af8d7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5243,14 +5243,14 @@ class AstBuilder extends DataTypeAstBuilder ctx.identifierReference, "ALTER TABLE ... ADD CONSTRAINT") val constraint = expression(ctx.constraint().asInstanceOf[CheckConstraintContext].booleanExpression()) - AlterTableAddConstraint(table, ctx.name.getText, constraint) + AddConstraint(table, ctx.name.getText, constraint) } override def visitDropTableConstraint(ctx: DropTableConstraintContext): LogicalPlan = withOrigin(ctx) { val table = createUnresolvedTable( ctx.identifierReference, "ALTER TABLE ... DROP CONSTRAINT") - AlterTableDropConstraint( + DropConstraint( table, ctx.name.getText, ifExists = ctx.EXISTS() != null, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala index c2368bda3697..10176ad4b74d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala @@ -292,7 +292,7 @@ case class AlterTableCollation( /** * The logical plan of the ALTER TABLE ... ADD CONSTRAINT command. */ -case class AlterTableAddConstraint( +case class AddConstraint( table: LogicalPlan, name: String, constraintExpr: Expression) extends AlterTableCommand { @@ -310,7 +310,7 @@ case class AlterTableAddConstraint( /** * The logical plan of the ALTER TABLE ... DROP CONSTRAINT command. */ -case class AlterTableDropConstraint( +case class DropConstraint( table: LogicalPlan, name: String, ifExists: Boolean, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala index c266d173d651..73f99235f6e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{GreaterThan, Literal} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.AlterTableAddConstraint +import org.apache.spark.sql.catalyst.plans.logical.AddConstraint import org.apache.spark.sql.test.SharedSparkSession class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSession { @@ -31,7 +31,7 @@ class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSes |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) |""".stripMargin val parsed = parsePlan(sql) - val expected = AlterTableAddConstraint( + val expected = AddConstraint( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... 
ADD CONSTRAINT"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropConstraintParseSuite.scala index 98ff55674f34..eaf51a980324 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropConstraintParseSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedTable} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.AlterTableDropConstraint +import org.apache.spark.sql.catalyst.plans.logical.DropConstraint import org.apache.spark.sql.test.SharedSparkSession class AlterTableDropConstraintParseSuite extends AnalysisTest with SharedSparkSession { @@ -27,7 +27,7 @@ class AlterTableDropConstraintParseSuite extends AnalysisTest with SharedSparkSe test("Drop constraint") { val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1" val parsed = parsePlan(sql) - val expected = AlterTableDropConstraint( + val expected = DropConstraint( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... DROP CONSTRAINT"), @@ -40,7 +40,7 @@ class AlterTableDropConstraintParseSuite extends AnalysisTest with SharedSparkSe test("Drop constraint if exists") { val sql = "ALTER TABLE a.b.c DROP CONSTRAINT IF EXISTS c1" val parsed = parsePlan(sql) - val expected = AlterTableDropConstraint( + val expected = DropConstraint( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... DROP CONSTRAINT"), @@ -53,7 +53,7 @@ class AlterTableDropConstraintParseSuite extends AnalysisTest with SharedSparkSe test("Drop constraint cascade") { val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1 CASCADE" val parsed = parsePlan(sql) - val expected = AlterTableDropConstraint( + val expected = DropConstraint( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... DROP CONSTRAINT"), @@ -66,7 +66,7 @@ class AlterTableDropConstraintParseSuite extends AnalysisTest with SharedSparkSe test("Drop constraint restrict") { val sql = "ALTER TABLE a.b.c DROP CONSTRAINT c1 RESTRICT" val parsed = parsePlan(sql) - val expected = AlterTableDropConstraint( + val expected = DropConstraint( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... 
DROP CONSTRAINT"), From b60e100b2a8158bc070c2b63ce120acadd903415 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 21 Feb 2025 16:31:41 -0800 Subject: [PATCH 05/65] add methods in CatalogV2Util and tests --- .../resources/error/error-conditions.json | 14 ++++++ .../sql/connector/catalog/Constraint.java | 2 +- .../spark/sql/connector/catalog/Table.java | 5 +++ .../sql/connector/catalog/CatalogV2Util.scala | 42 +++++++++++++++++- .../sql/connector/catalog/InMemoryTable.scala | 3 +- .../catalog/InMemoryTableCatalog.scala | 8 +++- .../command/v2/CheckConstraintSuite.scala | 44 ++++++++++++++++++- .../command/v2/CommandSuiteBase.scala | 9 +++- 8 files changed, 121 insertions(+), 6 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index d20b110e665a..c1abaddc1f39 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -788,6 +788,20 @@ }, "sqlState" : "XX000" }, + "CONSTRAINT_ALREADY_EXISTS" : { + "message" : [ + "Constraint '' already exists. Please delete the old constraint first.", + "Old constraint:", + "" + ], + "sqlState" : "42710" + }, + "CONSTRAINT_DOES_NOT_EXIST" : { + "message" : [ + "Cannot drop nonexistent constraint from table ." + ], + "sqlState" : "42704" + }, "CONVERSION_INVALID_INPUT" : { "message" : [ "The value () cannot be converted to because it is malformed. Correct the value as per the syntax, or change its format. Use to tolerate malformed input and return NULL instead." diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java index 263c83fbf13e..01692d5f9abd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java @@ -46,7 +46,7 @@ private Check(String name, Predicate predicate) { @Override public String toDDL() { - return "Check (" + predicate.toString() + ")"; + return "CHECK (" + predicate.toString() + ")"; } @Override public boolean rely() { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java index d5eb03dcf94d..419cc01793d3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java @@ -83,4 +83,9 @@ default Map properties() { * Returns the set of capabilities for this table. */ Set capabilities(); + + /** + * Returns the constraints for this table. 
+ */ + default Constraint[] constraints() { return new Constraint[0]; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 97cc263c56c5..acd3a1380732 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.catalog import java.util -import java.util.Collections +import java.util.{Collections, Locale} import scala.jdk.CollectionConverters._ @@ -194,6 +194,46 @@ private[sql] object CatalogV2Util { newPartitioning } + /** + * Apply Table Constraints changes to an existing set of constraints and return the result. + */ + def applyConstraintChanges( + table: Table, + changes: Seq[TableChange]): Array[Constraint] = { + val constraints = table.constraints() + changes.foldLeft(constraints) { (constraints, change) => + change match { + case add: AddCheckConstraint => + val newConstraint = add.getConstraint + val existingConstraint = + constraints.find( + _.name.toLowerCase(Locale.ROOT) == newConstraint.name().toLowerCase(Locale.ROOT)) + if (existingConstraint.isDefined) { + throw new AnalysisException( + errorClass = "CONSTRAINT_ALREADY_EXISTS", + messageParameters = + Map("constraintName" -> existingConstraint.get.name, + "oldConstraint" -> existingConstraint.get.toDDL)) + } + constraints :+ newConstraint + + case drop: DropConstraint => + val existingConstraint = constraints.find(_.name == drop.getName) + if (existingConstraint.isEmpty && !drop.isIfExists) { + throw new AnalysisException( + errorClass = "CONSTRAINT_DOES_NOT_EXIST", + messageParameters = + Map("constraintName" -> drop.getName, "tableName" -> table.name())) + } + constraints.filterNot(_.name == drop.getName) + + case _ => + // ignore non-constraint changes + constraints + } + }.toArray + } + /** * Apply schema changes to a schema and return the result. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index c27b8fea059f..d3eea9437a66 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -40,7 +40,8 @@ class InMemoryTable( numPartitions: Option[Int] = None, advisoryPartitionSize: Option[Long] = None, isDistributionStrictlyRequired: Boolean = true, - override val numRowsPerSplit: Int = Int.MaxValue) + override val numRowsPerSplit: Int = Int.MaxValue, + override val constraints: Array[Constraint] = Array.empty) extends InMemoryBaseTable(name, schema, partitioning, properties, distribution, ordering, numPartitions, advisoryPartitionSize, isDistributionStrictlyRequired, numRowsPerSplit) with SupportsDelete { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index 56ed3bb243e1..ca1f7737c2c3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -124,13 +124,19 @@ class BasicInMemoryTableCatalog extends TableCatalog { val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes) val schema = CatalogV2Util.applySchemaChanges(table.schema, changes, None, "ALTER TABLE") val finalPartitioning = CatalogV2Util.applyClusterByChanges(table.partitioning, schema, changes) + val constraints = CatalogV2Util.applyConstraintChanges(table, changes) // fail if the last column in the schema was dropped if (schema.fields.isEmpty) { throw new IllegalArgumentException(s"Cannot drop all fields") } - val newTable = new InMemoryTable(table.name, schema, finalPartitioning, properties) + val newTable = new InMemoryTable( + table.name, + schema, + finalPartitioning, + properties, + constraints = constraints) .withData(table.data) tables.put(ident, newTable) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index ef64a0aca817..39bea8a4a3e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.connector.catalog.Constraint.Check +import org.apache.spark.sql.execution.command.DDLCommandTestUtils -class CheckConstraintSuite extends QueryTest with CommandSuiteBase { +class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { test("Nondeterministic expression") { withTable("t") { sql("create table t(i double) using parquet") @@ -92,4 +94,44 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase { ) } } + + test("Add check constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT c1 CHECK (id > 0)") + val 
table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + assert(table.constraints.head.isInstanceOf[Check]) + val constraint = table.constraints.head.asInstanceOf[Check] + assert(constraint.name() == "c1") + assert(constraint.rely()) + assert(constraint.enforced()) + assert(constraint.toDDL == "CHECK (id > 0)") + } + } + + test("Add duplicated check constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT abc CHECK (id > 0)") + // Constraint names are case-insensitive + Seq("abc", "ABC").foreach { name => + val error = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ADD CONSTRAINT $name CHECK (id > 0)") + } + checkError( + exception = error, + condition = "CONSTRAINT_ALREADY_EXISTS", + sqlState = "42710", + parameters = Map("constraintName" -> "abc", "oldConstraint" -> "CHECK (id > 0)") + ) + } + } + } + + override protected def command: String = "ALTER TABLE .. ADD CONSTRAINT" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala index 6ba60e245f9b..8e7a4f55fc5f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.analysis.ResolvePartitionSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier, InMemoryCatalog, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier, InMemoryCatalog, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTable, InMemoryTableCatalog} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.ArrayImplicits._ @@ -64,4 +64,11 @@ trait CommandSuiteBase extends SharedSparkSession { assert(partMetadata.containsKey("location")) assert(partMetadata.get("location") === expected) } + + def loadTable(catalog: String, schema : String, table: String): InMemoryTable = { + import CatalogV2Implicits._ + val catalogPlugin = spark.sessionState.catalogManager.catalog(catalog) + catalogPlugin.asTableCatalog.loadTable(Identifier.of(Array(schema), table)) + .asInstanceOf[InMemoryTable] + } } From d2b7f0a42e33cbcdd5a73d47ba390fa4cfb7b0c4 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 24 Feb 2025 11:41:42 -0800 Subject: [PATCH 06/65] add DropConstraintSuite --- .../sql/connector/catalog/CatalogV2Util.scala | 11 ++- .../command/v2/CheckConstraintSuite.scala | 4 +- .../command/v2/DropConstraintSuite.scala | 93 +++++++++++++++++++ 3 files changed, 102 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index acd3a1380732..fdf5693ae34d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -201,13 +201,16 @@ private[sql] object CatalogV2Util { table: Table, changes: Seq[TableChange]): Array[Constraint] = { val constraints = table.constraints() + + def findExistingConstraint(name: String): Option[Constraint] = { + constraints.find(_.name.toLowerCase(Locale.ROOT) == name.toLowerCase(Locale.ROOT)) + } + changes.foldLeft(constraints) { (constraints, change) => change match { case add: AddCheckConstraint => val newConstraint = add.getConstraint - val existingConstraint = - constraints.find( - _.name.toLowerCase(Locale.ROOT) == newConstraint.name().toLowerCase(Locale.ROOT)) + val existingConstraint = findExistingConstraint(newConstraint.name) if (existingConstraint.isDefined) { throw new AnalysisException( errorClass = "CONSTRAINT_ALREADY_EXISTS", @@ -218,7 +221,7 @@ private[sql] object CatalogV2Util { constraints :+ newConstraint case drop: DropConstraint => - val existingConstraint = constraints.find(_.name == drop.getName) + val existingConstraint = findExistingConstraint(drop.getName) if (existingConstraint.isEmpty && !drop.isIfExists) { throw new AnalysisException( errorClass = "CONSTRAINT_DOES_NOT_EXIST", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index 39bea8a4a3e4..bf90823d87af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.connector.catalog.Constraint.Check import org.apache.spark.sql.execution.command.DDLCommandTestUtils class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { + override protected def command: String = "ALTER TABLE .. ADD CONSTRAINT" + test("Nondeterministic expression") { withTable("t") { sql("create table t(i double) using parquet") @@ -132,6 +134,4 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma } } } - - override protected def command: String = "ALTER TABLE .. ADD CONSTRAINT" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala new file mode 100644 index 000000000000..74633f4f8e1a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.connector.catalog.Constraint.Check +import org.apache.spark.sql.execution.command.DDLCommandTestUtils + +class DropConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { + override protected def command: String = "ALTER TABLE .. DROP CONSTRAINT" + + test("Drop a non-exist constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + val e = intercept[AnalysisException] { + sql(s"ALTER TABLE $t DROP CONSTRAINT c1") + } + checkError( + exception = e, + condition = "CONSTRAINT_DOES_NOT_EXIST", + sqlState = "42704", + parameters = Map("constraintName" -> "c1", "tableName" -> "test_catalog.ns.tbl") + ) + } + } + + test("Drop a non-exist constraint if exists") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + sql(s"ALTER TABLE $t DROP CONSTRAINT IF EXISTS c1") + } + } + + test("Drop a constraint on a non-exist table") { + Seq("", "IF EXISTS").foreach { ifExists => + val e = intercept[AnalysisException] { + sql(s"ALTER TABLE test_catalog.ns.tbl DROP CONSTRAINT $ifExists c1") + } + checkError( + exception = e, + condition = "TABLE_OR_VIEW_NOT_FOUND", + sqlState = "42P01", + parameters = Map("relationName" -> "`test_catalog`.`ns`.`tbl`"), + context = ExpectedContext( + fragment = "test_catalog.ns.tbl", + start = 12, + stop = 30 + ) + ) + } + } + + test("Drop existing constraints") { + Seq("", "IF EXISTS").foreach { ifExists => + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + sql(s"ALTER TABLE $t ADD CONSTRAINT c1 CHECK (id > 0)") + sql(s"ALTER TABLE $t ADD CONSTRAINT c2 CHECK (len(data) > 0)") + sql(s"ALTER TABLE $t DROP CONSTRAINT $ifExists c1") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head.asInstanceOf[Check] + assert(constraint.name() == "c2") + + sql(s"ALTER TABLE $t DROP CONSTRAINT $ifExists c2") + val table2 = loadTable(catalog, "ns", "tbl") + assert(table2.constraints.length == 0) + } + } + } + + test("Drop constraint is case insensitive") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + sql(s"ALTER TABLE $t ADD CONSTRAINT abc CHECK (id > 0)") + sql(s"ALTER TABLE $t DROP CONSTRAINT aBC") + } + } +} From 8264d5a14eefaeeec4b1fb1cb66600558f93ff80 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 24 Feb 2025 14:34:23 -0800 Subject: [PATCH 07/65] refactor parse changes --- .../resources/error/error-conditions.json | 5 ++++ .../sql/catalyst/parser/SqlBaseParser.g4 | 10 +++++-- .../spark/sql/errors/QueryParsingErrors.scala | 7 +++++ .../sql/catalyst/parser/AstBuilder.scala | 30 +++++++++++++++++-- 4 files changed, 46 insertions(+), 6 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index c1abaddc1f39..d55b89f742f4 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5643,6 +5643,11 @@ "Attach a comment to the namespace ." ] }, + "CONSTRAINT_TYPE" : { + "message" : [ + "Constraint ." 
+ ] + }, "CONTINUE_EXCEPTION_HANDLER" : { "message" : [ "CONTINUE exception handler is not supported. Use EXIT handler." diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 86d9d7b2d0f2..729f3bc7cd55 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -261,8 +261,7 @@ statement | ALTER TABLE identifierReference (clusterBySpec | CLUSTER BY NONE) #alterClusterBy | ALTER TABLE identifierReference collationSpec #alterTableCollation - | ALTER TABLE identifierReference ADD CONSTRAINT name=identifier - constraint #addTableConstraint + | ALTER TABLE identifierReference ADD constraintSpec #addTableConstraint | ALTER TABLE identifierReference DROP CONSTRAINT (IF EXISTS)? name=identifier (RESTRICT | CASCADE)? #dropTableConstraint @@ -563,6 +562,7 @@ createTableClauses locationSpec | commentSpec | collationSpec | + constraintSpec | (TBLPROPERTIES tableProps=propertyList))* ; @@ -1521,7 +1521,11 @@ number | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral ; -constraint +constraintSpec + : CONSTRAINT constraintName=errorCapturingIdentifier constraintExpression + ; + +constraintExpression : CHECK '(' booleanExpression ')' #checkConstraint ; diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 0bd9f3801498..f13899d6e40a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -791,4 +791,11 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { def clusterByWithBucketing(ctx: ParserRuleContext): Throwable = { new ParseException(errorClass = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED", ctx) } + + def constraintNotSupportedError(ctx: ParserRuleContext, constraint: String): Throwable = { + new ParseException( + errorClass = "UNSUPPORTED_FEATURE.CONSTRAINT_TYPE", + messageParameters = Map("constraint" -> constraint), + ctx) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 337729af8d7b..bc54cbd1ee41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5237,15 +5237,39 @@ class AstBuilder extends DataTypeAstBuilder AlterTableCollation(table, visitCollationSpec(ctx.collationSpec())) } + override def visitConstraintSpec(ctx: ConstraintSpecContext): Expression = { + ctx.constraintExpression() match { + case c: CheckConstraintContext => expression(c.booleanExpression()) + case other => + throw QueryParsingErrors.constraintNotSupportedError(ctx, other.getText) + } + } + + /** + * Parse a [[AddConstraint]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 CONSTRAINT constraint_name CHECK (a > 0) + * }}} + */ override def visitAddTableConstraint(ctx: AddTableConstraintContext): LogicalPlan = withOrigin(ctx) { val table = createUnresolvedTable( ctx.identifierReference, "ALTER TABLE ... 
ADD CONSTRAINT") - val constraint = - expression(ctx.constraint().asInstanceOf[CheckConstraintContext].booleanExpression()) - AddConstraint(table, ctx.name.getText, constraint) + val constraintExpression = visitConstraintSpec(ctx.constraintSpec()) + AddConstraint(table, ctx.constraintSpec().constraintName.getText, constraintExpression) } + + /** + * Parse a [[DropConstraint]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 DROP CONSTRAINT constraint_name + * }}} + */ override def visitDropTableConstraint(ctx: DropTableConstraintContext): LogicalPlan = withOrigin(ctx) { val table = createUnresolvedTable( From 5c57ce271cd99fb91a86e85eaf5659afa932ac06 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 24 Feb 2025 17:00:20 -0800 Subject: [PATCH 08/65] fix visitCreateTableClauses --- .../spark/sql/catalyst/parser/AstBuilder.scala | 12 ++++++++---- .../apache/spark/sql/execution/SparkSqlParser.scala | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bc54cbd1ee41..b0d1f948ef10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -4138,7 +4138,8 @@ class AstBuilder extends DataTypeAstBuilder */ type TableClauses = ( Seq[Transform], Seq[ColumnDefinition], Option[BucketSpec], Map[String, String], OptionList, - Option[String], Option[String], Option[String], Option[SerdeInfo], Option[ClusterBySpec]) + Option[String], Option[String], Option[String], Option[SerdeInfo], Option[ClusterBySpec], + Seq[Expression]) /** * Validate a create table statement and return the [[TableIdentifier]]. @@ -4641,8 +4642,10 @@ class AstBuilder extends DataTypeAstBuilder } } + val constraints = ctx.constraintSpec().asScala.map(visitConstraintSpec).toSeq + (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment, - collation, serdeInfo, clusterBySpec) + collation, serdeInfo, clusterBySpec, constraints) } protected def getSerdeInfo( @@ -4717,7 +4720,8 @@ class AstBuilder extends DataTypeAstBuilder val columns = Option(ctx.colDefinitionList()).map(visitColDefinitionList).getOrElse(Nil) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) val (partTransforms, partCols, bucketSpec, properties, options, location, comment, - collation, serdeInfo, clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses()) + collation, serdeInfo, clusterBySpec, constraints) = + visitCreateTableClauses(ctx.createTableClauses()) if (provider.isDefined && serdeInfo.isDefined) { invalidStatement(s"CREATE TABLE ... USING ... 
${serdeInfo.get.describe}", ctx) @@ -4795,7 +4799,7 @@ class AstBuilder extends DataTypeAstBuilder override def visitReplaceTable(ctx: ReplaceTableContext): LogicalPlan = withOrigin(ctx) { val orCreate = ctx.replaceTableHeader().CREATE() != null val (partTransforms, partCols, bucketSpec, properties, options, location, comment, collation, - serdeInfo, clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses()) + serdeInfo, clusterBySpec, constraints) = visitCreateTableClauses(ctx.createTableClauses()) val columns = Option(ctx.colDefinitionList()).map(visitColDefinitionList).getOrElse(Nil) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8859b7b421b3..efb1602a2b69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -361,7 +361,7 @@ class SparkSqlAstBuilder extends AstBuilder { invalidStatement("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx) } - val (_, _, _, _, options, location, _, _, _, _) = + val (_, _, _, _, options, location, _, _, _, _, _) = visitCreateTableClauses(ctx.createTableClauses()) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText).getOrElse( throw QueryParsingErrors.createTempTableNotSpecifyProviderError(ctx)) From c3a7c415145d4395a9d153983963a6cafce86f32 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 24 Feb 2025 17:40:28 -0800 Subject: [PATCH 09/65] rename as AddCheckConstraint --- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- .../sql/catalyst/plans/logical/v2AlterTableCommands.scala | 2 +- .../execution/command/AlterTableAddConstraintParseSuite.scala | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c94f5476d30b..fba6c0a5588b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1191,7 +1191,7 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString case _ => } - case addConstraint @ AddConstraint(table: ResolvedTable, _, constraintExpr) => + case addConstraint @ AddCheckConstraint(table: ResolvedTable, _, constraintExpr) => if (!constraintExpr.resolved) { constraintExpr.failAnalysis( errorClass = "INVALID_CHECK_CONSTRAINT.UNRESOLVED", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b0d1f948ef10..dcc77d842fae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5250,7 +5250,7 @@ class AstBuilder extends DataTypeAstBuilder } /** - * Parse a [[AddConstraint]] command. + * Parse a [[AddCheckConstraint]] command. * * For example: * {{{ @@ -5262,7 +5262,7 @@ class AstBuilder extends DataTypeAstBuilder val table = createUnresolvedTable( ctx.identifierReference, "ALTER TABLE ... 
ADD CONSTRAINT") val constraintExpression = visitConstraintSpec(ctx.constraintSpec()) - AddConstraint(table, ctx.constraintSpec().constraintName.getText, constraintExpression) + AddCheckConstraint(table, ctx.constraintSpec().constraintName.getText, constraintExpression) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala index 10176ad4b74d..aa3770e2bde4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala @@ -292,7 +292,7 @@ case class AlterTableCollation( /** * The logical plan of the ALTER TABLE ... ADD CONSTRAINT command. */ -case class AddConstraint( +case class AddCheckConstraint( table: LogicalPlan, name: String, constraintExpr: Expression) extends AlterTableCommand { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala index 73f99235f6e0..32082489504d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{GreaterThan, Literal} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.AddConstraint +import org.apache.spark.sql.catalyst.plans.logical.AddCheckConstraint import org.apache.spark.sql.test.SharedSparkSession class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSession { @@ -31,7 +31,7 @@ class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSes |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) |""".stripMargin val parsed = parsePlan(sql) - val expected = AddConstraint( + val expected = AddCheckConstraint( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... 
ADD CONSTRAINT"), From 90e14040dd6f8a58fd6f9a71491b8f34bbd76a96 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 26 Feb 2025 23:55:42 -0800 Subject: [PATCH 10/65] save for now --- .../spark/sql/catalyst/analysis/ResolveTableSpec.scala | 3 ++- .../apache/spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- .../spark/sql/catalyst/plans/logical/v2Commands.scala | 9 ++++++--- .../scala/org/apache/spark/sql/classic/Catalog.scala | 3 ++- .../org/apache/spark/sql/classic/DataFrameWriter.scala | 9 ++++++--- .../org/apache/spark/sql/classic/DataFrameWriterV2.scala | 6 ++++-- .../org/apache/spark/sql/classic/DataStreamWriter.scala | 3 ++- 7 files changed, 24 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala index 05158fbee3de..38d924abb91e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala @@ -94,7 +94,8 @@ object ResolveTableSpec extends Rule[LogicalPlan] { comment = u.comment, collation = u.collation, serde = u.serde, - external = u.external) + external = u.external, + constraints = Seq.empty) withNewSpec(newTableSpec) case _ => input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index dcc77d842fae..8a60891312ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -4739,7 +4739,7 @@ class AstBuilder extends DataTypeAstBuilder clusterBySpec.map(_.asTransform) val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, - collation, serdeInfo, external) + collation, serdeInfo, external, constraints) Option(ctx.query).map(plan) match { case Some(_) if columns.nonEmpty => @@ -4813,7 +4813,7 @@ class AstBuilder extends DataTypeAstBuilder clusterBySpec.map(_.asTransform) val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, - collation, serdeInfo, external = false) + collation, serdeInfo, external = false, constraints) Option(ctx.query).map(plan) match { case Some(_) if columns.nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 1056a30c5f75..12f52968491a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -1520,7 +1520,8 @@ case class UnresolvedTableSpec( comment: Option[String], collation: Option[String], serde: Option[SerdeInfo], - external: Boolean) extends UnaryExpression with Unevaluable with TableSpecBase { + external: Boolean, + constraints: Seq[Expression]) extends UnaryExpression with Unevaluable with TableSpecBase { override def dataType: DataType = throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3113") @@ -1566,9 +1567,11 @@ case class TableSpec( comment: Option[String], collation: Option[String], serde: Option[SerdeInfo], - external: Boolean) extends TableSpecBase { + external: Boolean, + constraints: Seq[Constraint]) extends 
TableSpecBase { def withNewLocation(newLocation: Option[String]): TableSpec = { - TableSpec(properties, provider, options, newLocation, comment, collation, serde, external) + TableSpec(properties, provider, options, newLocation, + comment, collation, serde, external, constraints) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala index aff65496b763..3b4f6475a6bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala @@ -688,7 +688,8 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog { comment = { if (description.isEmpty) None else Some(description) }, collation = None, serde = None, - external = tableType == CatalogTableType.EXTERNAL) + external = tableType == CatalogTableType.EXTERNAL, + constraints = Seq.empty) val plan = CreateTable( name = UnresolvedIdentifier(ident), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala index b423c89fff3d..501b4985128d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala @@ -213,7 +213,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram comment = extraOptions.get(TableCatalog.PROP_COMMENT), collation = extraOptions.get(TableCatalog.PROP_COLLATION), serde = None, - external = false) + external = false, + constraints = Seq.empty) runCommand(df.sparkSession) { CreateTableAsSelect( UnresolvedIdentifier( @@ -478,7 +479,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram comment = extraOptions.get(TableCatalog.PROP_COMMENT), collation = extraOptions.get(TableCatalog.PROP_COLLATION), serde = None, - external = false) + external = false, + constraints = Seq.empty) ReplaceTableAsSelect( UnresolvedIdentifier(nameParts), partitioningAsV2, @@ -499,7 +501,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram comment = extraOptions.get(TableCatalog.PROP_COMMENT), collation = extraOptions.get(TableCatalog.PROP_COLLATION), serde = None, - external = false) + external = false, + constraints = Seq.empty) CreateTableAsSelect( UnresolvedIdentifier(nameParts), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala index e4efee93d2a0..01b3619f1236 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala @@ -154,7 +154,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) comment = None, collation = None, serde = None, - external = false) + external = false, + constraints = Seq.empty) runCommand( CreateTableAsSelect( UnresolvedIdentifier(tableName), @@ -220,7 +221,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) comment = None, collation = None, serde = None, - external = false) + external = false, + constraints = Seq.empty) runCommand(ReplaceTableAsSelect( UnresolvedIdentifier(tableName), partitioning.getOrElse(Seq.empty) ++ clustering, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala index 96e875557754..471c5feadaab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala @@ -175,7 +175,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D None, None, None, - external = false) + external = false, + constraints = Seq.empty) val cmd = CreateTable( UnresolvedIdentifier(originalMultipartIdentifier), ds.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)), From 6381a2414b9a16599f141aba6ec70945034c5735 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 3 Mar 2025 17:26:56 -0800 Subject: [PATCH 11/65] refactor check constraint --- .../sql/connector/catalog/Constraint.java | 15 ++++-- .../sql/catalyst/analysis/CheckAnalysis.scala | 9 +--- .../sql/catalyst/parser/AstBuilder.scala | 24 ++++----- .../plans/logical/v2AlterTableCommands.scala | 3 +- .../catalyst/plans/logical/v2Commands.scala | 5 +- .../AlterTableAddConstraintParseSuite.scala | 1 + .../command/v2/CheckConstraintSuite.scala | 54 +++++++++---------- 7 files changed, 55 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java index 01692d5f9abd..d6b15a452a9f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java @@ -25,17 +25,20 @@ public interface Constraint { boolean rely(); // indicates whether the constraint is believed to be true boolean enforced(); // indicates whether the constraint must be enforced - static Constraint check(String name, Predicate predicate) { - return new Check(name, predicate); + static Constraint check(String name, String sql, Predicate predicate) { + return new Check(name, sql, predicate); } final class Check implements Constraint { private final String name; + private final String sql; private final Predicate predicate; - private Check(String name, Predicate predicate) { + private Check(String name, String sql, Predicate predicate) { this.name = name; + this.sql = sql; this.predicate = predicate; } + @Override public String name() { return name; } @@ -46,7 +49,7 @@ private Check(String name, Predicate predicate) { @Override public String toDDL() { - return "CHECK (" + predicate.toString() + ")"; + return "CHECK (" + sql + ")"; } @Override public boolean rely() { @@ -56,5 +59,9 @@ public String toDDL() { @Override public boolean enforced() { return true; } + + public String sql() { return sql; } + + public Predicate predicate() { return predicate; } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index fba6c0a5588b..0a5ff2006147 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1191,7 +1191,7 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString case _ => } - case addConstraint @ AddCheckConstraint(table: ResolvedTable, _, constraintExpr) => + case addConstraint @ AddCheckConstraint(table: ResolvedTable, _, _, constraintExpr) => if (!constraintExpr.resolved) { 
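          // Unresolved references in the CHECK expression surface as a dedicated error class
          // instead of a generic analysis failure.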
constraintExpr.failAnalysis( errorClass = "INVALID_CHECK_CONSTRAINT.UNRESOLVED", @@ -1205,13 +1205,6 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString messageParameters = Map.empty ) } - - if (addConstraint.predicate.isEmpty) { - constraintExpr.failAnalysis( - errorClass = "INVALID_CHECK_CONSTRAINT.INVALID_V2_PREDICATE", - messageParameters = Map.empty - ) - } case _ => } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8a60891312ce..ac7552ff94a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -4642,10 +4642,10 @@ class AstBuilder extends DataTypeAstBuilder } } - val constraints = ctx.constraintSpec().asScala.map(visitConstraintSpec).toSeq + // val constraints = ctx.constraintSpec().asScala.map(visitConstraintSpec).toSeq (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment, - collation, serdeInfo, clusterBySpec, constraints) + collation, serdeInfo, clusterBySpec, Seq.empty) } protected def getSerdeInfo( @@ -5241,14 +5241,6 @@ class AstBuilder extends DataTypeAstBuilder AlterTableCollation(table, visitCollationSpec(ctx.collationSpec())) } - override def visitConstraintSpec(ctx: ConstraintSpecContext): Expression = { - ctx.constraintExpression() match { - case c: CheckConstraintContext => expression(c.booleanExpression()) - case other => - throw QueryParsingErrors.constraintNotSupportedError(ctx, other.getText) - } - } - /** * Parse a [[AddCheckConstraint]] command. * @@ -5261,8 +5253,16 @@ class AstBuilder extends DataTypeAstBuilder withOrigin(ctx) { val table = createUnresolvedTable( ctx.identifierReference, "ALTER TABLE ... 
ADD CONSTRAINT") - val constraintExpression = visitConstraintSpec(ctx.constraintSpec()) - AddCheckConstraint(table, ctx.constraintSpec().constraintName.getText, constraintExpression) + ctx.constraintSpec.constraintExpression() match { + case c: CheckConstraintContext => + AddCheckConstraint( + table, + ctx.constraintSpec().constraintName.getText, + c.booleanExpression().getText, + expression(c.booleanExpression())) + case other => + throw QueryParsingErrors.constraintNotSupportedError(ctx, other.getText) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala index aa3770e2bde4..8b713b059f4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala @@ -295,12 +295,13 @@ case class AlterTableCollation( case class AddCheckConstraint( table: LogicalPlan, name: String, + constraintText: String, constraintExpr: Expression) extends AlterTableCommand { lazy val predicate = new V2ExpressionBuilder(constraintExpr, true).buildPredicate() override def changes: Seq[TableChange] = { - val constraint = Constraint.check(name, predicate.get) + val constraint = Constraint.check(name, constraintText, predicate.orNull) Seq(TableChange.addCheckConstraint(constraint, constraint.enforced())) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 12f52968491a..0a90f7b72ee3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -1521,7 +1521,8 @@ case class UnresolvedTableSpec( collation: Option[String], serde: Option[SerdeInfo], external: Boolean, - constraints: Seq[Expression]) extends UnaryExpression with Unevaluable with TableSpecBase { + constraints: Seq[Expression] = Seq.empty) + extends UnaryExpression with Unevaluable with TableSpecBase { override def dataType: DataType = throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3113") @@ -1568,7 +1569,7 @@ case class TableSpec( collation: Option[String], serde: Option[SerdeInfo], external: Boolean, - constraints: Seq[Constraint]) extends TableSpecBase { + constraints: Seq[Constraint] = Seq.empty) extends TableSpecBase { def withNewLocation(newLocation: Option[String]): TableSpec = { TableSpec(properties, provider, options, newLocation, comment, collation, serde, external, constraints) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala index 32082489504d..6ca1715b7f12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -36,6 +36,7 @@ class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSes Seq("a", "b", "c"), "ALTER TABLE ... 
ADD CONSTRAINT"), "c1", + "d > 0", GreaterThan(UnresolvedAttribute("d"), Literal(0))) comparePlans(parsed, expected) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index bf90823d87af..e568e19b72ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.connector.catalog.Constraint.Check +import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.execution.command.DDLCommandTestUtils class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { @@ -73,27 +74,26 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma } } - test("Can't convert expression to V2 predicate") { - withTable("t") { - sql("create table t(i string) using parquet") - val query = - """ - |ALTER TABLE t ADD CONSTRAINT c1 CHECK (from_json(i, 'a INT').a > 1) - |""".stripMargin - val error = intercept[AnalysisException] { - sql(query) - } - checkError( - exception = error, - condition = "INVALID_CHECK_CONSTRAINT.INVALID_V2_PREDICATE", - sqlState = "42621", - parameters = Map.empty, - context = ExpectedContext( - fragment = "from_json(i, 'a INT').a > 1", - start = 40, - stop = 66 - ) - ) + private def getCheckConstraint(table: Table): Check = { + assert(table.constraints.length == 1) + assert(table.constraints.head.isInstanceOf[Check]) + table.constraints.head.asInstanceOf[Check] + val constraint = table.constraints.head.asInstanceOf[Check] + assert(constraint.rely()) + assert(constraint.enforced()) + constraint + } + + test("Predicate should be null if it can't be converted to V2 predicate") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, j string) $defaultUsing") + sql(s"ALTER TABLE $t ADD CONSTRAINT c1 CHECK (from_json(j, 'a INT').a > 1)") + val table = loadTable(catalog, "ns", "tbl") + val constraint = getCheckConstraint(table) + assert(constraint.name() == "c1") + assert(constraint.toDDL == "CHECK (from_json(j,'a INT').a>1)") + assert(constraint.sql() == "from_json(j,'a INT').a>1") + assert(constraint.predicate() == null) } } @@ -104,13 +104,9 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma sql(s"ALTER TABLE $t ADD CONSTRAINT c1 CHECK (id > 0)") val table = loadTable(catalog, "ns", "tbl") - assert(table.constraints.length == 1) - assert(table.constraints.head.isInstanceOf[Check]) - val constraint = table.constraints.head.asInstanceOf[Check] + val constraint = getCheckConstraint(table) assert(constraint.name() == "c1") - assert(constraint.rely()) - assert(constraint.enforced()) - assert(constraint.toDDL == "CHECK (id > 0)") + assert(constraint.toDDL == "CHECK (id>0)") } } @@ -123,13 +119,13 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma // Constraint names are case-insensitive Seq("abc", "ABC").foreach { name => val error = intercept[AnalysisException] { - sql(s"ALTER TABLE $t ADD CONSTRAINT $name CHECK (id > 0)") + sql(s"ALTER TABLE $t ADD CONSTRAINT $name CHECK (id>0)") } checkError( exception = error, condition = "CONSTRAINT_ALREADY_EXISTS", sqlState = "42710", - 
parameters = Map("constraintName" -> "abc", "oldConstraint" -> "CHECK (id > 0)") + parameters = Map("constraintName" -> "abc", "oldConstraint" -> "CHECK (id>0)") ) } } From 5eaa0694c1931fbd315631f7eb552f44a6645ed7 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 3 Mar 2025 22:26:49 -0800 Subject: [PATCH 12/65] rename AddConstraint --- .../spark/sql/connector/catalog/TableChange.java | 10 +++++----- .../catalyst/plans/logical/v2AlterTableCommands.scala | 2 +- .../spark/sql/connector/catalog/CatalogV2Util.scala | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java index beaa4e4afaa8..3fa3720b001e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java @@ -263,8 +263,8 @@ static TableChange clusterBy(NamedReference[] clusteringColumns) { /** * Create a TableChange for adding a new Table Constraint */ - static TableChange addCheckConstraint(Constraint constraint, Boolean validate) { - return new AddCheckConstraint(constraint, validate); + static TableChange addConstraint(Constraint constraint, Boolean validate) { + return new AddConstraint(constraint, validate); } /** @@ -807,11 +807,11 @@ public int hashCode() { } /** A TableChange to alter table and add a constraint. */ - final class AddCheckConstraint implements TableChange { + final class AddConstraint implements TableChange { private final Constraint constraint; private final boolean validate; - private AddCheckConstraint(Constraint constraint, boolean validate) { + private AddConstraint(Constraint constraint, boolean validate) { this.constraint = constraint; this.validate = validate; } @@ -828,7 +828,7 @@ public boolean isValidate() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - AddCheckConstraint that = (AddCheckConstraint) o; + AddConstraint that = (AddConstraint) o; return constraint.equals(that.constraint) && validate == that.validate; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala index 8b713b059f4e..bf64e959f2c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala @@ -302,7 +302,7 @@ case class AddCheckConstraint( override def changes: Seq[TableChange] = { val constraint = Constraint.check(name, constraintText, predicate.orNull) - Seq(TableChange.addCheckConstraint(constraint, constraint.enforced())) + Seq(TableChange.addConstraint(constraint, constraint.enforced())) } protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index fdf5693ae34d..0997aed99e8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -208,7 +208,7 @@ private[sql] object 
CatalogV2Util { changes.foldLeft(constraints) { (constraints, change) => change match { - case add: AddCheckConstraint => + case add: AddConstraint => val newConstraint = add.getConstraint val existingConstraint = findExistingConstraint(newConstraint.name) if (existingConstraint.isDefined) { From a63c1d97a1cdf9ee612cd61f6cc1f5b0d07bea46 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 4 Mar 2025 13:38:40 -0800 Subject: [PATCH 13/65] introduce CheckConstraint expr --- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../catalyst/expressions/constraints.scala | 46 +++++++++++++++++++ .../sql/catalyst/parser/AstBuilder.scala | 31 ++++++++----- .../plans/logical/v2AlterTableCommands.scala | 15 ++---- .../AlterTableAddConstraintParseSuite.scala | 9 ++-- 5 files changed, 76 insertions(+), 27 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 0a5ff2006147..a9bbf25dd5c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1191,7 +1191,7 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString case _ => } - case addConstraint @ AddCheckConstraint(table: ResolvedTable, _, _, constraintExpr) => + case addConstraint @ AddCheckConstraint(table: ResolvedTable, constraintExpr) => if (!constraintExpr.resolved) { constraintExpr.failAnalysis( errorClass = "INVALID_CHECK_CONSTRAINT.UNRESOLVED", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala new file mode 100644 index 000000000000..a7b19e8e7028 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder +import org.apache.spark.sql.connector.catalog.Constraint +import org.apache.spark.sql.types.{DataType, StringType} + +trait ConstraintExpression extends Expression with Unevaluable { + override def nullable: Boolean = true + + override def dataType: DataType = StringType + + def asConstraint: Constraint +} + +case class CheckConstraint( + name: String, + override val sql: String, + child: Expression) extends ConstraintExpression + with UnaryLike[Expression] { + + def asConstraint: Constraint = { + val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull + Constraint.check(name, sql, predicate) + } + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ac7552ff94a3..4d3da1c67759 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -4139,7 +4139,7 @@ class AstBuilder extends DataTypeAstBuilder type TableClauses = ( Seq[Transform], Seq[ColumnDefinition], Option[BucketSpec], Map[String, String], OptionList, Option[String], Option[String], Option[String], Option[SerdeInfo], Option[ClusterBySpec], - Seq[Expression]) + Seq[ConstraintExpression]) /** * Validate a create table statement and return the [[TableIdentifier]]. @@ -4642,10 +4642,10 @@ class AstBuilder extends DataTypeAstBuilder } } - // val constraints = ctx.constraintSpec().asScala.map(visitConstraintSpec).toSeq + val constraints = ctx.constraintSpec().asScala.map(visitConstraintSpec).toSeq (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment, - collation, serdeInfo, clusterBySpec, Seq.empty) + collation, serdeInfo, clusterBySpec, constraints) } protected def getSerdeInfo( @@ -5241,6 +5241,19 @@ class AstBuilder extends DataTypeAstBuilder AlterTableCollation(table, visitCollationSpec(ctx.collationSpec())) } + override def visitConstraintSpec(ctx: ConstraintSpecContext): ConstraintExpression = { + ctx.constraintExpression() match { + case c: CheckConstraintContext => + CheckConstraint( + name = ctx.constraintName.getText, + sql = c.booleanExpression().getText, + child = expression(c.booleanExpression()) + ) + case other => + throw QueryParsingErrors.constraintNotSupportedError(ctx, other.getText) + } + } + /** * Parse a [[AddCheckConstraint]] command. * @@ -5253,15 +5266,9 @@ class AstBuilder extends DataTypeAstBuilder withOrigin(ctx) { val table = createUnresolvedTable( ctx.identifierReference, "ALTER TABLE ... 
ADD CONSTRAINT") - ctx.constraintSpec.constraintExpression() match { - case c: CheckConstraintContext => - AddCheckConstraint( - table, - ctx.constraintSpec().constraintName.getText, - c.booleanExpression().getText, - expression(c.booleanExpression())) - case other => - throw QueryParsingErrors.constraintNotSupportedError(ctx, other.getText) + visitConstraintSpec(ctx.constraintSpec) match { + case c: CheckConstraint => + AddCheckConstraint(table, c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala index bf64e959f2c5..4a34c1654876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, ResolvedFieldName, UnresolvedException} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ClusterBySpec -import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable} -import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, TypeUtils, V2ExpressionBuilder} -import org.apache.spark.sql.connector.catalog.{Constraint, TableCatalog, TableChange} +import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, Expression, Unevaluable} +import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, TypeUtils} +import org.apache.spark.sql.connector.catalog.{TableCatalog, TableChange} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.DataType import org.apache.spark.util.ArrayImplicits._ @@ -294,14 +294,9 @@ case class AlterTableCollation( */ case class AddCheckConstraint( table: LogicalPlan, - name: String, - constraintText: String, - constraintExpr: Expression) extends AlterTableCommand { - - lazy val predicate = new V2ExpressionBuilder(constraintExpr, true).buildPredicate() - + check: CheckConstraint) extends AlterTableCommand { override def changes: Seq[TableChange] = { - val constraint = Constraint.check(name, constraintText, predicate.orNull) + val constraint = check.asConstraint Seq(TableChange.addConstraint(constraint, constraint.enforced())) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala index 6ca1715b7f12..23bed2a1b827 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedTable} -import org.apache.spark.sql.catalyst.expressions.{GreaterThan, Literal} +import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, GreaterThan, Literal} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.AddCheckConstraint @@ -35,9 +35,10 @@ class AlterTableAddConstraintParseSuite extends 
AnalysisTest with SharedSparkSes UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... ADD CONSTRAINT"), - "c1", - "d > 0", - GreaterThan(UnresolvedAttribute("d"), Literal(0))) + CheckConstraint( + "c1", + "d > 0", + GreaterThan(UnresolvedAttribute("d"), Literal(0)))) comparePlans(parsed, expected) } From 0063155ad5837b0a93f7a1180e5090125bbf6eb0 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 4 Mar 2025 15:01:30 -0800 Subject: [PATCH 14/65] introduce Constraints expression --- .../catalyst/analysis/ResolveTableSpec.scala | 4 +-- .../catalyst/expressions/constraints.scala | 25 +++++++++++++++++++ .../sql/catalyst/parser/AstBuilder.scala | 4 +-- .../catalyst/plans/logical/v2Commands.scala | 24 +++++++++++------- ...eateTablePartitioningValidationSuite.scala | 5 ++-- .../sql/catalyst/parser/DDLParserSuite.scala | 20 ++++++++------- .../apache/spark/sql/classic/Catalog.scala | 4 +-- .../spark/sql/classic/DataFrameWriter.scala | 8 +++--- .../spark/sql/classic/DataFrameWriterV2.scala | 6 ++--- .../spark/sql/classic/DataStreamWriter.scala | 3 ++- .../V2CommandsCaseSensitivitySuite.scala | 9 ++++--- 11 files changed, 74 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala index 38d924abb91e..7c6bfd6e4a3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala @@ -61,7 +61,7 @@ object ResolveTableSpec extends Rule[LogicalPlan] { input: LogicalPlan, tableSpec: TableSpecBase, withNewSpec: TableSpecBase => LogicalPlan): LogicalPlan = tableSpec match { - case u: UnresolvedTableSpec if u.optionExpression.resolved => + case u: UnresolvedTableSpec if u.optionExpression.resolved && u.constraints.childrenResolved => val newOptions: Seq[(String, String)] = u.optionExpression.options.map { case (key: String, null) => (key, null) @@ -95,7 +95,7 @@ object ResolveTableSpec extends Rule[LogicalPlan] { collation = u.collation, serde = u.serde, external = u.external, - constraints = Seq.empty) + constraints = u.constraints.asConstraintList) withNewSpec(newTableSpec) case _ => input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index a7b19e8e7028..128ecce247a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder import org.apache.spark.sql.connector.catalog.Constraint @@ -44,3 +45,27 @@ case class CheckConstraint( copy(child = newChild) } +/* + * A list of constraints that are applied to a table. 
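+ * Each child is expected to be a [[ConstraintExpression]] (currently only
+ * [[CheckConstraint]]). The list is carried on [[UnresolvedTableSpec]] so that CHECK
+ * expressions are analyzed together with the rest of the plan and then converted to
+ * connector [[Constraint]]s via `asConstraintList`. For example,
+ * `CREATE TABLE t (a INT) USING parquet CONSTRAINT c1 CHECK (a > 0)` produces
+ * Constraints(Seq(CheckConstraint("c1", "a>0", GreaterThan(...)))).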
+ */ +case class Constraints(children: Seq[Expression]) extends Expression with Unevaluable { + + assert(children.forall(_.isInstanceOf[ConstraintExpression])) + + override def nullable: Boolean = true + + override def dataType: DataType = + throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3113") + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + copy(children = newChildren) + } + + def asConstraintList: Seq[Constraint] = + children.map(_.asInstanceOf[ConstraintExpression].asConstraint) +} + +object Constraints { + val empty: Constraints = Constraints(Nil) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4d3da1c67759..e0965036cc02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -4139,7 +4139,7 @@ class AstBuilder extends DataTypeAstBuilder type TableClauses = ( Seq[Transform], Seq[ColumnDefinition], Option[BucketSpec], Map[String, String], OptionList, Option[String], Option[String], Option[String], Option[SerdeInfo], Option[ClusterBySpec], - Seq[ConstraintExpression]) + Constraints) /** * Validate a create table statement and return the [[TableIdentifier]]. @@ -4642,7 +4642,7 @@ class AstBuilder extends DataTypeAstBuilder } } - val constraints = ctx.constraintSpec().asScala.map(visitConstraintSpec).toSeq + val constraints = Constraints(ctx.constraintSpec().asScala.map(visitConstraintSpec).toSeq) (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment, collation, serdeInfo, clusterBySpec, constraints) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 0a90f7b72ee3..a9fccb87523b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -23,11 +23,11 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.catalog.{FunctionResource, RoutineLanguage} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, UnaryExpression, Unevaluable, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils, ReplaceDataProjections, RowDeltaUtils, WriteDeltaProjections} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils.{ordinalNumber, toSQLExpr} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper} @@ -1521,20 +1521,26 @@ case class UnresolvedTableSpec( collation: Option[String], serde: Option[SerdeInfo], 
external: Boolean, - constraints: Seq[Expression] = Seq.empty) - extends UnaryExpression with Unevaluable with TableSpecBase { + constraints: Constraints) + extends BinaryExpression with Unevaluable with TableSpecBase { override def dataType: DataType = throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3113") - override def child: Expression = optionExpression - - override protected def withNewChildInternal(newChild: Expression): Expression = - this.copy(optionExpression = newChild.asInstanceOf[OptionList]) - override def simpleString(maxFields: Int): String = { this.copy(properties = Utils.redact(properties).toMap).toString } + + override def left: Expression = optionExpression + + override def right: Expression = constraints + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = + copy(optionExpression = newLeft.asInstanceOf[OptionList], + constraints = newRight.asInstanceOf[Constraints]) + + override def nullable: Boolean = true } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala index 133670d5fcce..bf94758267fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Constraints} import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode, OptionList, UnresolvedTableSpec} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, Table, TableCapability, TableCatalog} @@ -30,7 +30,8 @@ import org.apache.spark.util.ArrayImplicits._ class CreateTablePartitioningValidationSuite extends AnalysisTest { val tableSpec = - UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, None, false) + UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, None, false, + Constraints.empty) test("CreateTableAsSelect: fail missing top-level column") { val plan = CreateTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index c8d2de9c6b8d..6d4df1432d87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.SparkThrowable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.{EqualTo, Hex, Literal} +import org.apache.spark.sql.catalyst.expressions.{Constraints, EqualTo, Hex, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.catalog.IdentityColumnSpec import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.{after, first} @@ -2705,7 +2705,7 @@ class 
DDLParserSuite extends AnalysisTest { val createTableResult = CreateTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithDefaultValue, Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), - OptionList(Seq.empty), None, None, None, None, false), false) + OptionList(Seq.empty), None, None, None, None, false, Constraints.empty), false) // Parse the CREATE TABLE statement twice, swapping the order of the NOT NULL and DEFAULT // options, to make sure that the parser accepts any ordering of these options. comparePlans(parsePlan( @@ -2718,7 +2718,7 @@ class DDLParserSuite extends AnalysisTest { "b STRING NOT NULL DEFAULT 'abc') USING parquet"), ReplaceTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithDefaultValue, Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), - OptionList(Seq.empty), None, None, None, None, false), false)) + OptionList(Seq.empty), None, None, None, None, false, Constraints.empty), false)) // These ALTER TABLE statements should parse successfully. comparePlans( parsePlan("ALTER TABLE t1 ADD COLUMN x int NOT NULL DEFAULT 42"), @@ -2881,12 +2881,12 @@ class DDLParserSuite extends AnalysisTest { "CREATE TABLE my_tab(a INT, b INT NOT NULL GENERATED ALWAYS AS (a+1)) USING parquet"), CreateTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithGenerationExpr, Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), - OptionList(Seq.empty), None, None, None, None, false), false)) + OptionList(Seq.empty), None, None, None, None, false, Constraints.empty), false)) comparePlans(parsePlan( "REPLACE TABLE my_tab(a INT, b INT NOT NULL GENERATED ALWAYS AS (a+1)) USING parquet"), ReplaceTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithGenerationExpr, Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), - OptionList(Seq.empty), None, None, None, None, false), false)) + OptionList(Seq.empty), None, None, None, None, false, Constraints.empty), false)) // Two generation expressions checkError( exception = parseException("CREATE TABLE my_tab(a INT, " + @@ -2957,7 +2957,8 @@ class DDLParserSuite extends AnalysisTest { None, None, None, - false + false, + Constraints.empty ), false ) @@ -2980,7 +2981,8 @@ class DDLParserSuite extends AnalysisTest { None, None, None, - false + false, + Constraints.empty ), false ) @@ -3273,7 +3275,7 @@ class DDLParserSuite extends AnalysisTest { Seq(ColumnDefinition("c", StringType)), Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), - None, None, Some(collation), None, false), false)) + None, None, Some(collation), None, false, Constraints.empty), false)) } } @@ -3285,7 +3287,7 @@ class DDLParserSuite extends AnalysisTest { Seq(ColumnDefinition("c", StringType)), Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), - None, None, Some(collation), None, false), false)) + None, None, Some(collation), None, false, Constraints.empty), false)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala index 3b4f6475a6bc..30118744ca78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import 
org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{Constraints, Expression, Literal} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, LocalRelation, LogicalPlan, OptionList, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, UnresolvedTableSpec, View} import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -689,7 +689,7 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog { collation = None, serde = None, external = tableType == CatalogTableType.EXTERNAL, - constraints = Seq.empty) + constraints = Constraints.empty) val plan = CreateTable( name = UnresolvedIdentifier(ident), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala index 501b4985128d..6733f31ec017 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, NoSuchTableException, UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.{Constraints, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog._ @@ -214,7 +214,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram collation = extraOptions.get(TableCatalog.PROP_COLLATION), serde = None, external = false, - constraints = Seq.empty) + constraints = Constraints.empty) runCommand(df.sparkSession) { CreateTableAsSelect( UnresolvedIdentifier( @@ -480,7 +480,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram collation = extraOptions.get(TableCatalog.PROP_COLLATION), serde = None, external = false, - constraints = Seq.empty) + constraints = Constraints.empty) ReplaceTableAsSelect( UnresolvedIdentifier(nameParts), partitioningAsV2, @@ -502,7 +502,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram collation = extraOptions.get(TableCatalog.PROP_COLLATION), serde = None, external = false, - constraints = Seq.empty) + constraints = Constraints.empty) CreateTableAsSelect( UnresolvedIdentifier(nameParts), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala index 01b3619f1236..fdedb1a50c47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation} -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} +import 
org.apache.spark.sql.catalyst.expressions.{Attribute, Constraints, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.catalog.TableWritePrivilege._ import org.apache.spark.sql.connector.expressions._ @@ -155,7 +155,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) collation = None, serde = None, external = false, - constraints = Seq.empty) + constraints = Constraints.empty) runCommand( CreateTableAsSelect( UnresolvedIdentifier(tableName), @@ -222,7 +222,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) collation = None, serde = None, external = false, - constraints = Seq.empty) + constraints = Constraints.empty) runCommand(ReplaceTableAsSelect( UnresolvedIdentifier(tableName), partitioning.getOrElse(Seq.empty) ++ clustering, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala index 471c5feadaab..4eff6adb28b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.{streaming, Dataset => DS, ForeachWriter} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.Constraints import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, OptionList, UnresolvedTableSpec} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -176,7 +177,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D None, None, external = false, - constraints = Seq.empty) + constraints = Constraints.empty) val cmd = CreateTable( UnresolvedIdentifier(originalMultipartIdentifier), ds.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala index a1089b4291e9..1f012099bda4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, TestTable2, UnresolvedFieldName, UnresolvedFieldPosition, UnresolvedIdentifier} +import org.apache.spark.sql.catalyst.expressions.Constraints import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumns, AlterColumnSpec, AlterTableCommand, CreateTableAsSelect, DropColumns, LogicalPlan, OptionList, QualifiedColType, RenameColumn, ReplaceColumns, ReplaceTableAsSelect, UnresolvedTableSpec} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes @@ -54,7 +55,7 @@ class V2CommandsCaseSensitivitySuite Seq("ID", "iD").foreach { ref => val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), - None, None, None, None, false) + None, None, None, None, false, 
Constraints.empty) val plan = CreateTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.identity(ref) :: Nil, @@ -79,7 +80,7 @@ class V2CommandsCaseSensitivitySuite Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), - None, None, None, None, false) + None, None, None, None, false, Constraints.empty) val plan = CreateTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, ref) :: Nil, @@ -105,7 +106,7 @@ class V2CommandsCaseSensitivitySuite Seq("ID", "iD").foreach { ref => val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), - None, None, None, None, false) + None, None, None, None, false, Constraints.empty) val plan = ReplaceTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.identity(ref) :: Nil, @@ -130,7 +131,7 @@ class V2CommandsCaseSensitivitySuite Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), - None, None, None, None, false) + None, None, None, None, false, Constraints.empty) val plan = ReplaceTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, ref) :: Nil, From af2fa61022c53304be4b172d6073cdfa5d129b3e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 4 Mar 2025 20:55:18 -0800 Subject: [PATCH 15/65] add CreateTableConstraintParseSuite --- .../v1/CreateTableConstraintParseSuite.scala | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala new file mode 100644 index 000000000000..8e681b23abe4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command.v1 + +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedIdentifier} +import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, Constraints, EqualTo, GreaterThan, Literal} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, OptionList, UnresolvedTableSpec} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType} + +class CreateTableConstraintParseSuite extends AnalysisTest with SharedSparkSession { + val createTablePrefix = "CREATE TABLE t (a INT, b STRING) USING parquet" + val tableId = UnresolvedIdentifier(Seq("t")) + val columns = Seq( + ColumnDefinition("a", IntegerType), + ColumnDefinition("b", StringType) + ) + + def verifyConstraints(constraintStr: String, constraints: Constraints): Unit = { + val sql = + s""" + |$createTablePrefix + |$constraintStr + |""".stripMargin + + val parsed = parsePlan(sql) + val tableSpec = UnresolvedTableSpec( + Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), + None, None, None, None, false, constraints) + val expected = CreateTable(tableId, columns, Seq.empty, tableSpec, false) + comparePlans(parsed, expected) + } + + test("Create table with one check constraint") { + val constraintStr = "CONSTRAINT c1 CHECK (a > 0)" + val constraint = CheckConstraint("c1", "a>0", GreaterThan(UnresolvedAttribute("a"), Literal(0))) + val constraints = Constraints(Seq(constraint)) + verifyConstraints(constraintStr, constraints) + } + + test("Create table with two check constraints") { + val constraintStr = "CONSTRAINT c1 CHECK (a > 0) CONSTRAINT c2 CHECK (b = 'foo')" + val constraint1 = + CheckConstraint("c1", "a>0", GreaterThan(UnresolvedAttribute("a"), Literal(0))) + val constraint2 = + CheckConstraint("c2", "b='foo'", EqualTo(UnresolvedAttribute("b"), Literal("foo"))) + val constraints = Constraints(Seq(constraint1, constraint2)) + verifyConstraints(constraintStr, constraints) + } +} From 5d86c93c703ccf8fb87ab291e9f823c5855679df Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 4 Mar 2025 21:50:04 -0800 Subject: [PATCH 16/65] add new create table api --- .../sql/connector/catalog/TableCatalog.java | 24 +++++++++++++++++++ .../catalog/InMemoryTableCatalog.scala | 10 +++++--- .../datasources/v2/CreateTableExec.scala | 3 ++- .../datasources/v2/ReplaceTableExec.scala | 3 ++- .../spark/sql/RuntimeNullChecksV2Writes.scala | 14 +++++++---- .../KeyGroupedPartitioningSuite.scala | 2 +- .../WriteDistributionAndOrderingSuite.scala | 8 ++++--- 7 files changed, 50 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index 77dbaa7687b4..d882c14a335f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -230,6 +230,30 @@ default Table createTable( return createTable(ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties); } + /** + * Create a table in the catalog. + * + * @param ident a table identifier + * @param columns the columns of the new table. 
+ * @param partitions transforms to use for partitioning data in the table + * @param properties a string map of table properties + * @param constraints constraints for the new table + * @return metadata for the new table. This can be null if getting the metadata for the new table + * is expensive. Spark will call {@link #loadTable(Identifier)} if needed (e.g. CTAS). + * + * @throws TableAlreadyExistsException If a table or view already exists for the identifier + * @throws UnsupportedOperationException If a requested partition transform is not supported + * @throws NoSuchNamespaceException If the identifier namespace does not exist (optional) + */ + default Table createTable( + Identifier ident, + Column[] columns, + Transform[] partitions, + Map properties, + Constraint[] constraints) throws TableAlreadyExistsException, NoSuchNamespaceException { + return createTable(ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties); + } + /** * If true, mark all the fields of the query schema as nullable when executing * CREATE/REPLACE TABLE ... AS SELECT ... and creating the table. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index ca1f7737c2c3..2cbbb389a7b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -87,11 +87,13 @@ class BasicInMemoryTableCatalog extends TableCatalog { ident: Identifier, columns: Array[Column], partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: util.Map[String, String], + constraints: Array[Constraint]): Table = { createTable(ident, columns, partitions, properties, Distributions.unspecified(), - Array.empty, None, None) + Array.empty, None, None, constraints) } + // scalastyle:off argcount def createTable( ident: Identifier, columns: Array[Column], @@ -101,8 +103,10 @@ class BasicInMemoryTableCatalog extends TableCatalog { ordering: Array[SortOrder], requiredNumPartitions: Option[Int], advisoryPartitionSize: Option[Long], + constraints: Array[Constraint], distributionStrictlyRequired: Boolean = true, numRowsPerSplit: Int = Int.MaxValue): Table = { + // scalastyle:on argcount val schema = CatalogV2Util.v2ColumnsToStructType(columns) if (tables.containsKey(ident)) { throw new TableAlreadyExistsException(ident.asMultipartIdentifier) @@ -113,7 +117,7 @@ class BasicInMemoryTableCatalog extends TableCatalog { val tableName = s"$name.${ident.quoted}" val table = new InMemoryTable(tableName, schema, partitions, properties, distribution, ordering, requiredNumPartitions, advisoryPartitionSize, distributionStrictlyRequired, - numRowsPerSplit) + numRowsPerSplit, constraints) tables.put(ident, table) namespaces.putIfAbsent(ident.namespace.toList, Map()) table diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala index f55fbafe11dd..25e5292f3672 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala @@ -43,7 +43,8 @@ case class CreateTableExec( override protected def run(): Seq[InternalRow] = { if 
(!catalog.tableExists(identifier)) { try { - catalog.createTable(identifier, columns, partitioning.toArray, tableProperties.asJava) + catalog.createTable(identifier, columns, partitioning.toArray, tableProperties.asJava, + tableSpec.constraints.toArray) } catch { case _: TableAlreadyExistsException if ignoreIfExists => logWarning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index 894a3a10d419..9c0122c4cd31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -48,7 +48,8 @@ case class ReplaceTableExec( } else if (!orCreate) { throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) } - catalog.createTable(ident, columns, partitioning.toArray, tableProperties.asJava) + catalog.createTable(ident, columns, partitioning.toArray, tableProperties.asJava, + tableSpec.constraints.toArray) Seq.empty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala index b48ff7121c76..4cc6c02333a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.util.Collections import org.apache.spark.{SparkConf, SparkRuntimeException} -import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, Identifier, InMemoryTableCatalog} +import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, Constraint, Identifier, InMemoryTableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} @@ -214,7 +214,8 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS ColumnV2.create("i", IntegerType), ColumnV2.create("arr", ArrayType(structType, containsNull = false))), partitions = Array.empty[Transform], - properties = Collections.emptyMap[String, String]) + properties = Collections.emptyMap[String, String], + constraints = Array.empty[Constraint]) if (byName) { val inputDF = sql("SELECT 1 AS i, null AS arr") @@ -261,7 +262,8 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS ColumnV2.create("i", IntegerType), ColumnV2.create("arr", ArrayType(structType, containsNull = true))), partitions = Array.empty[Transform], - properties = Collections.emptyMap[String, String]) + properties = Collections.emptyMap[String, String], + constraints = Array.empty[Constraint]) if (byName) { val inputDF = sql( @@ -315,7 +317,8 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS ColumnV2.create("i", IntegerType), ColumnV2.create("m", MapType(IntegerType, IntegerType, valueContainsNull = false))), partitions = Array.empty[Transform], - properties = Collections.emptyMap[String, String]) + properties = Collections.emptyMap[String, String], + constraints = Array.empty[Constraint]) if (byName) { val inputDF = sql("SELECT 1 AS i, null AS m") @@ -354,7 +357,8 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS ColumnV2.create("i", IntegerType), ColumnV2.create("m", MapType(structType, structType, 
valueContainsNull = true))), partitions = Array.empty[Transform], - properties = Collections.emptyMap[String, String]) + properties = Collections.emptyMap[String, String], + constraints = Array.empty[Constraint]) if (byName) { val inputDF = sql("SELECT 1 AS i, map(named_struct('x', 1, 'y', 1), null) AS m") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index c24f52bd9307..412066a1a41a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -238,7 +238,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { catalog: InMemoryTableCatalog = catalog): Unit = { catalog.createTable(Identifier.of(Array("ns"), table), columns, partitions, emptyProps, Distributions.unspecified(), Array.empty, None, None, - numRowsPerSplit = 1) + Array.empty, numRowsPerSplit = 1) } private val customers: String = "customers" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index f62e092138a9..01a4a52189f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -977,7 +977,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase ) val distribution = Distributions.ordered(ordering) - catalog.createTable(ident, columns, Array.empty, emptyProps, distribution, ordering, None, None) + catalog.createTable(ident, columns, Array.empty, emptyProps, + distribution, ordering, None, None, Array.empty) withTempDir { checkpointDir => val inputData = ContinuousMemoryStream[(Long, String, Date)] @@ -1218,7 +1219,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase // scalastyle:on argcount catalog.createTable(ident, columns, Array.empty, emptyProps, tableDistribution, - tableOrdering, tableNumPartitions, tablePartitionSize, distributionStrictlyRequired) + tableOrdering, tableNumPartitions, tablePartitionSize, Array.empty, + distributionStrictlyRequired) val df = if (!dataSkewed) { spark.createDataFrame(Seq( @@ -1320,7 +1322,7 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase expectAnalysisException: Boolean = false): Unit = { catalog.createTable(ident, columns, Array.empty, emptyProps, tableDistribution, - tableOrdering, tableNumPartitions, tablePartitionSize) + tableOrdering, tableNumPartitions, tablePartitionSize, Array.empty) withTempDir { checkpointDir => val inputData = MemoryStream[(Long, String, Date)] From a8972ec9cc970b0ccb1edc922535d6eeffdfa202 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 4 Mar 2025 22:35:31 -0800 Subject: [PATCH 17/65] add CreateTableConstraintSuite --- .../v2/CreateTableConstraintSuite.scala | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala new file 
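For context, a minimal Scala sketch of calling the constraint-aware createTable overload added above, mirroring the updated test call sites; the catalog setup, identifier, and column names here are illustrative assumptions, and an empty constraint array stands in for real constraints:

    import java.util.Collections
    import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, Constraint, Identifier, InMemoryTableCatalog}
    import org.apache.spark.sql.connector.expressions.Transform
    import org.apache.spark.sql.types.IntegerType
    import org.apache.spark.sql.util.CaseInsensitiveStringMap

    val catalog = new InMemoryTableCatalog                 // assumed test catalog, for illustration only
    catalog.initialize("cat", CaseInsensitiveStringMap.empty())
    catalog.createTable(
      Identifier.of(Array("ns"), "tbl"),                   // assumed namespace and table name
      Array(ColumnV2.create("id", IntegerType)),           // table columns
      Array.empty[Transform],                              // no partitioning
      Collections.emptyMap[String, String],                // no table properties
      Array.empty[Constraint])                             // constraints flow through this new argument

Note that the default TableCatalog implementation simply delegates to the older overload and drops the constraints, so only catalogs that override it, such as the in-memory test catalog updated above, actually store them.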
mode 100644 index 000000000000..f942bcc3ff5d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.connector.catalog.Constraint.Check +import org.apache.spark.sql.execution.command.DDLCommandTestUtils + +class CreateTableConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { + override protected def command: String = "CREATE TABLE .. CONSTRAINT" + + test("Create table with one check constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql( + s""" + |CREATE TABLE $t (id bigint, data string) $defaultUsing + | CONSTRAINT c1 CHECK (id > 0)""".stripMargin) + val constraints = loadTable(catalog, "ns", "tbl").constraints + assert(constraints.length == 1) + assert(constraints.head.isInstanceOf[Check]) + val constraint = constraints.head.asInstanceOf[Check] + + assert(constraint.name == "c1") + assert(constraint.sql == "id>0") + assert(constraint.predicate().toString() == "id > 0") + } + } +} From ea2385515868b8f3e6085f71f466dc588d55b9b4 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 5 Mar 2025 11:59:41 -0800 Subject: [PATCH 18/65] fix CreateTableConstraintSuite --- .../catalyst/analysis/ResolveTableSpec.scala | 2 +- .../v2/CreateTableConstraintSuite.scala | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala index 7c6bfd6e4a3d..3dcd76ee1e4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala @@ -61,7 +61,7 @@ object ResolveTableSpec extends Rule[LogicalPlan] { input: LogicalPlan, tableSpec: TableSpecBase, withNewSpec: TableSpecBase => LogicalPlan): LogicalPlan = tableSpec match { - case u: UnresolvedTableSpec if u.optionExpression.resolved && u.constraints.childrenResolved => + case u: UnresolvedTableSpec if u.optionExpression.resolved => val newOptions: Seq[(String, String)] = u.optionExpression.options.map { case (key: String, null) => (key, null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala index f942bcc3ff5d..26bc2ca8d9d0 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala @@ -40,4 +40,29 @@ class CreateTableConstraintSuite extends QueryTest with CommandSuiteBase with DD assert(constraint.predicate().toString() == "id > 0") } } + + test("Create table with two check constraints") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql( + s""" + |CREATE TABLE $t (id bigint, data string) $defaultUsing + | CONSTRAINT c1 CHECK (id > 0) + | CONSTRAINT c2 CHECK (data = 'foo')""".stripMargin) + val constraints = loadTable(catalog, "ns", "tbl").constraints + assert(constraints.length == 2) + assert(constraints.head.isInstanceOf[Check]) + val constraint = constraints.head.asInstanceOf[Check] + + assert(constraint.name == "c1") + assert(constraint.sql == "id>0") + assert(constraint.predicate().toString() == "id > 0") + + assert(constraints(1).isInstanceOf[Check]) + val constraint2 = constraints(1).asInstanceOf[Check] + + assert(constraint2.name == "c2") + assert(constraint2.sql == "data='foo'") + assert(constraint2.predicate().toString() == "data = 'foo'") + } + } } From 3944f6043fc4f65224b1ca10da89f2346ed4dd03 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 5 Mar 2025 12:59:09 -0800 Subject: [PATCH 19/65] resolve constraints with a fake project --- .../catalyst/analysis/ResolveTableSpec.scala | 30 +++++++++++++++---- .../v2/CreateTableConstraintSuite.scala | 9 ++++++ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala index 3dcd76ee1e4f..00eb3ed5e641 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkThrowable -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.resolveExpressionByPlanOutput +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Constraints, Expression, Literal} import org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -46,20 +47,30 @@ object ResolveTableSpec extends Rule[LogicalPlan] { preparedPlan.resolveOperatorsWithPruning(_.containsAnyPattern(COMMAND), ruleId) { case t: CreateTable => - resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s)) + resolveTableSpec(t, t.tableSpec, + fakeProjectFromColumns(t.columns), s => t.copy(tableSpec = s)) case t: CreateTableAsSelect => - resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s)) + resolveTableSpec(t, t.tableSpec, None, s => t.copy(tableSpec = s)) case t: ReplaceTable => - resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s)) + resolveTableSpec(t, t.tableSpec, + fakeProjectFromColumns(t.columns), s => t.copy(tableSpec = s)) case t: ReplaceTableAsSelect => - resolveTableSpec(t, t.tableSpec, s => t.copy(tableSpec = s)) + resolveTableSpec(t, t.tableSpec, None, s => t.copy(tableSpec = s)) } } + private def fakeProjectFromColumns(columns: Seq[ColumnDefinition]): Option[Project] = { + val fakeProjectList = columns.map { col => + 
AttributeReference(col.name, col.dataType)() + } + Some(Project(fakeProjectList, OneRowRelation())) + } + /** Helper method to resolve the table specification within a logical plan. */ private def resolveTableSpec( input: LogicalPlan, tableSpec: TableSpecBase, + fakeProject: Option[Project], withNewSpec: TableSpecBase => LogicalPlan): LogicalPlan = tableSpec match { case u: UnresolvedTableSpec if u.optionExpression.resolved => val newOptions: Seq[(String, String)] = u.optionExpression.options.map { @@ -86,6 +97,13 @@ object ResolveTableSpec extends Rule[LogicalPlan] { } (key, newValue) } + val newConstraints = if (fakeProject.isDefined) { + resolveExpressionByPlanOutput(u.constraints, fakeProject.get, throws = true) + .asInstanceOf[Constraints] + } else { + u.constraints + } + // assert(newConstraints.childrenResolved) val newTableSpec = TableSpec( properties = u.properties, provider = u.provider, @@ -95,7 +113,7 @@ object ResolveTableSpec extends Rule[LogicalPlan] { collation = u.collation, serde = u.serde, external = u.external, - constraints = u.constraints.asConstraintList) + constraints = newConstraints.asConstraintList) withNewSpec(newTableSpec) case _ => input diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala index 26bc2ca8d9d0..7058888a65d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala @@ -65,4 +65,13 @@ class CreateTableConstraintSuite extends QueryTest with CommandSuiteBase with DD assert(constraint2.predicate().toString() == "data = 'foo'") } } + + test("Create table with UnresolvedAttribute in check constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql( + s""" + |CREATE TABLE $t (id bigint, data string) $defaultUsing + | CONSTRAINT c2 CHECK (abc = 'foo')""".stripMargin) + } + } } From 6b1a4af807d073bfc0993b0724c1032333eee9d0 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 5 Mar 2025 17:34:50 -0800 Subject: [PATCH 20/65] resolve constraint with a default analyzer --- .../catalyst/analysis/ResolveTableSpec.scala | 48 ++++++++++++++----- .../v2/CreateTableConstraintSuite.scala | 4 +- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala index 00eb3ed5e641..2229a90183cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkThrowable -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.resolveExpressionByPlanOutput -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Constraints, Expression, Literal} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND +import 
org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.DefaultColumnAnalyzer import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, MapType, StructType} @@ -48,29 +49,53 @@ object ResolveTableSpec extends Rule[LogicalPlan] { preparedPlan.resolveOperatorsWithPruning(_.containsAnyPattern(COMMAND), ruleId) { case t: CreateTable => resolveTableSpec(t, t.tableSpec, - fakeProjectFromColumns(t.columns), s => t.copy(tableSpec = s)) + fakeRelationFromColumns(t.columns), s => t.copy(tableSpec = s)) case t: CreateTableAsSelect => resolveTableSpec(t, t.tableSpec, None, s => t.copy(tableSpec = s)) case t: ReplaceTable => resolveTableSpec(t, t.tableSpec, - fakeProjectFromColumns(t.columns), s => t.copy(tableSpec = s)) + fakeRelationFromColumns(t.columns), s => t.copy(tableSpec = s)) case t: ReplaceTableAsSelect => resolveTableSpec(t, t.tableSpec, None, s => t.copy(tableSpec = s)) } } - private def fakeProjectFromColumns(columns: Seq[ColumnDefinition]): Option[Project] = { - val fakeProjectList = columns.map { col => + private def fakeRelationFromColumns(columns: Seq[ColumnDefinition]): Option[LogicalPlan] = { + val attributeList = columns.map { col => AttributeReference(col.name, col.dataType)() } - Some(Project(fakeProjectList, OneRowRelation())) + Some(LocalRelation(attributeList)) + } + + private def analyzeConstraints( + constraints: Constraints, + fakeRelation: LogicalPlan): Constraints = { + val analyzedExpressions = constraints.children.map { + case c: CheckConstraint => + val alias = Alias(c.child, c.name)() + val project = Project(Seq(alias), fakeRelation) + val analyzed = DefaultColumnAnalyzer.execute(project) + try { + DefaultColumnAnalyzer.checkAnalysis(analyzed) + } catch { + case e: AnalysisException => + throw e.withPosition(c.origin) + } + + val analyzedExpression = analyzed collectFirst { + case Project(Seq(Alias(e: Expression, _)), _) => e + } + c.withNewChildren(Seq(analyzedExpression.get)) + case other => other + } + Constraints(analyzedExpressions) } /** Helper method to resolve the table specification within a logical plan. 
*/ private def resolveTableSpec( input: LogicalPlan, tableSpec: TableSpecBase, - fakeProject: Option[Project], + fakeRelation: Option[LogicalPlan], withNewSpec: TableSpecBase => LogicalPlan): LogicalPlan = tableSpec match { case u: UnresolvedTableSpec if u.optionExpression.resolved => val newOptions: Seq[(String, String)] = u.optionExpression.options.map { @@ -97,13 +122,12 @@ object ResolveTableSpec extends Rule[LogicalPlan] { } (key, newValue) } - val newConstraints = if (fakeProject.isDefined) { - resolveExpressionByPlanOutput(u.constraints, fakeProject.get, throws = true) - .asInstanceOf[Constraints] + val newConstraints = if (fakeRelation.isDefined) { + analyzeConstraints(u.constraints, fakeRelation.get) } else { u.constraints } - // assert(newConstraints.childrenResolved) + assert(newConstraints.childrenResolved) val newTableSpec = TableSpec( properties = u.properties, provider = u.provider, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala index 7058888a65d6..fa0016ef939b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala @@ -37,7 +37,7 @@ class CreateTableConstraintSuite extends QueryTest with CommandSuiteBase with DD assert(constraint.name == "c1") assert(constraint.sql == "id>0") - assert(constraint.predicate().toString() == "id > 0") + assert(constraint.predicate().toString() == "id > CAST(0 AS long)") } } @@ -55,7 +55,7 @@ class CreateTableConstraintSuite extends QueryTest with CommandSuiteBase with DD assert(constraint.name == "c1") assert(constraint.sql == "id>0") - assert(constraint.predicate().toString() == "id > 0") + assert(constraint.predicate().toString() == "id > CAST(0 AS long)") assert(constraints(1).isInstanceOf[Check]) val constraint2 = constraints(1).asInstanceOf[Check] From 46aa44890d06ff8de61cda8081c7f4168c6656dd Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 5 Mar 2025 23:02:13 -0800 Subject: [PATCH 21/65] improve error message --- .../sql/catalyst/analysis/ResolveTableSpec.scala | 8 +------- .../command/v2/CreateTableConstraintSuite.scala | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala index 2229a90183cf..fc8982b93ff6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkThrowable -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime import org.apache.spark.sql.catalyst.plans.logical._ @@ -75,12 +74,7 @@ object ResolveTableSpec extends Rule[LogicalPlan] { val alias = Alias(c.child, c.name)() val project = Project(Seq(alias), fakeRelation) val analyzed = DefaultColumnAnalyzer.execute(project) - try { - DefaultColumnAnalyzer.checkAnalysis(analyzed) - } catch { - case e: AnalysisException => - throw e.withPosition(c.origin) - } + 
DefaultColumnAnalyzer.checkAnalysis0(analyzed) val analyzedExpression = analyzed collectFirst { case Project(Seq(Alias(e: Expression, _)), _) => e diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala index fa0016ef939b..6404a719a932 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command.v2 -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.connector.catalog.Constraint.Check import org.apache.spark.sql.execution.command.DDLCommandTestUtils @@ -68,10 +68,19 @@ class CreateTableConstraintSuite extends QueryTest with CommandSuiteBase with DD test("Create table with UnresolvedAttribute in check constraint") { withNamespaceAndTable("ns", "tbl", catalog) { t => - sql( + val query = s""" |CREATE TABLE $t (id bigint, data string) $defaultUsing - | CONSTRAINT c2 CHECK (abc = 'foo')""".stripMargin) + | CONSTRAINT c2 CHECK (abc = 'foo')""".stripMargin + val e = intercept[AnalysisException] { + sql(query) + } + checkError( + exception = e, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`abc`", "proposal" -> "`id`, `data`"), + sqlState = "42703", + context = ExpectedContext("abc", 89, 91)) // UnresolvedAttribute abc } } } From 18f7c655cd144294c7c2188921abee7e963c73bc Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 20 Mar 2025 13:09:30 -0700 Subject: [PATCH 22/65] move constraint class --- .../java/org/apache/spark/sql/connector/catalog/Table.java | 1 + .../org/apache/spark/sql/connector/catalog/TableCatalog.java | 1 + .../org/apache/spark/sql/connector/catalog/TableChange.java | 1 + .../sql/connector/catalog/{ => constraints}/Constraint.java | 2 +- .../apache/spark/sql/catalyst/expressions/constraints.scala | 2 +- .../apache/spark/sql/catalyst/plans/logical/v2Commands.scala | 1 + .../org/apache/spark/sql/connector/catalog/CatalogV2Util.scala | 1 + .../org/apache/spark/sql/connector/catalog/InMemoryTable.scala | 1 + .../spark/sql/connector/catalog/InMemoryTableCatalog.scala | 1 + .../scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala | 3 ++- .../spark/sql/execution/command/v2/CheckConstraintSuite.scala | 2 +- .../sql/execution/command/v2/CreateTableConstraintSuite.scala | 2 +- .../spark/sql/execution/command/v2/DropConstraintSuite.scala | 2 +- 13 files changed, 14 insertions(+), 6 deletions(-) rename sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/{ => constraints}/Constraint.java (97%) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java index 419cc01793d3..166554b0b4ca 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.catalog; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.constraints.Constraint; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.types.StructType; diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index d882c14a335f..f98a72b0f9b6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.catalog; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.constraints.Constraint; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java index 3fa3720b001e..43b8efaf55fc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java @@ -22,6 +22,7 @@ import javax.annotation.Nullable; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.constraints.Constraint; import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.types.DataType; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java similarity index 97% rename from sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java index d6b15a452a9f..d53ce970655a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Constraint.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.connector.catalog; +package org.apache.spark.sql.connector.catalog.constraints; import org.apache.spark.sql.connector.expressions.filter.Predicate; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index 128ecce247a9..bc46d0ca4d85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder -import org.apache.spark.sql.connector.catalog.Constraint +import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.types.{DataType, StringType} trait ConstraintExpression extends Expression with Unevaluable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index a9fccb87523b..914257a40c74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils.{ordinalNumber, toSQLExpr} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper} +import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.expressions.filter.Predicate diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 0997aed99e8a..3dc8d16d3eef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec} import org.apache.spark.sql.catalyst.util.{GeneratedColumn, IdentityColumn} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.TableChange._ +import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.expressions.{ClusterByTransform, LiteralValue, Transform} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index d3eea9437a66..e8f2cf6979e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -19,6 +19,7 @@ package 
org.apache.spark.sql.connector.catalog import java.util +import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite, WriteBuilder, WriterCommitMessage} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index 2cbbb389a7b2..5397e6cfd999 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentHashMap import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NonEmptyNamespaceException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} import org.apache.spark.sql.util.CaseInsensitiveStringMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala index 4cc6c02333a2..cf5212351b52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql import java.util.Collections import org.apache.spark.{SparkConf, SparkRuntimeException} -import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, Constraint, Identifier, InMemoryTableCatalog} +import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, Identifier, InMemoryTableCatalog} +import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index e568e19b72ae..1dac4fd85380 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.connector.catalog.Constraint.Check import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.catalog.constraints.Constraint.Check import org.apache.spark.sql.execution.command.DDLCommandTestUtils class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala index 6404a719a932..0a9cd948a9f4 100644 
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.connector.catalog.Constraint.Check +import org.apache.spark.sql.connector.catalog.constraints.Constraint.Check import org.apache.spark.sql.execution.command.DDLCommandTestUtils class CreateTableConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala index 74633f4f8e1a..368bf71c53d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.connector.catalog.Constraint.Check +import org.apache.spark.sql.connector.catalog.constraints.Constraint.Check import org.apache.spark.sql.execution.command.DDLCommandTestUtils class DropConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { From 3acca28f0d901c6650b2deb3368dd3253862c951 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 20 Mar 2025 13:13:29 -0700 Subject: [PATCH 23/65] remove Constraint.java --- .../catalog/constraints/Constraint.java | 67 ------------------- 1 file changed, 67 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java deleted file mode 100644 index d53ce970655a..000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.connector.catalog.constraints; - -import org.apache.spark.sql.connector.expressions.filter.Predicate; - -public interface Constraint { - String name(); // either assigned by the user or auto generated - String description(); // used in toString() - String toDDL(); // used in EXPLAIN/DESCRIBE/SHOW CREATE TABLE - boolean rely(); // indicates whether the constraint is believed to be true - boolean enforced(); // indicates whether the constraint must be enforced - - static Constraint check(String name, String sql, Predicate predicate) { - return new Check(name, sql, predicate); - } - - final class Check implements Constraint { - private final String name; - private final String sql; - private final Predicate predicate; - private Check(String name, String sql, Predicate predicate) { - this.name = name; - this.sql = sql; - this.predicate = predicate; - } - - @Override public String name() { - return name; - } - - @Override public String description() { - return "check constraint"; - } - - @Override - public String toDDL() { - return "CHECK (" + sql + ")"; - } - - @Override public boolean rely() { - return true; - } - - @Override public boolean enforced() { - return true; - } - - public String sql() { return sql; } - - public Predicate predicate() { return predicate; } - } -} From a0ab768229ed580e2ae677d45d18701cc2fc14ba Mon Sep 17 00:00:00 2001 From: Anton Okolnychyi Date: Tue, 11 Mar 2025 08:24:43 -0700 Subject: [PATCH 24/65] [SPARK-51441][SQL] Add DSv2 APIs for constraints --- .../catalog/constraints/BaseConstraint.java | 71 ++++++++++ .../connector/catalog/constraints/Check.java | 90 ++++++++++++ .../catalog/constraints/Constraint.java | 128 ++++++++++++++++++ .../catalog/constraints/ConstraintState.java | 116 ++++++++++++++++ .../catalog/constraints/ForeignKey.java | 104 ++++++++++++++ .../catalog/constraints/PrimaryKey.java | 74 ++++++++++ .../connector/catalog/constraints/Unique.java | 73 ++++++++++ .../connector/catalog/ConstraintSuite.scala | 106 +++++++++++++++ 8 files changed, 762 insertions(+) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ConstraintState.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java new file mode 100644 index 000000000000..2eb80a99d552 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import org.apache.spark.sql.connector.expressions.NamedReference; + +import java.util.StringJoiner; + +abstract class BaseConstraint implements Constraint { + + private final String name; + private final ConstraintState state; + + protected BaseConstraint(String name, ConstraintState state) { + this.name = name; + this.state = state; + } + + protected abstract String definition(); + + @Override + public String name() { + return name; + } + + @Override + public ConstraintState state() { + return state; + } + + @Override + public String toDDL() { + return String.format( + "CONSTRAINT %s %s %s %s %s", + name, + definition(), + state.enforced() ? "ENFORCED" : "NOT ENFORCED", + state.validated() ? "VALID" : "NOT VALID", + state.rely() ? "RELY" : "NORELY"); + } + + @Override + public String toString() { + return toDDL(); + } + + protected String toSQL(NamedReference[] columns) { + StringJoiner joiner = new StringJoiner(", "); + + for (NamedReference column : columns) { + joiner.add(column.toString()); + } + + return joiner.toString(); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java new file mode 100644 index 000000000000..6f05f7450fc5 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import org.apache.spark.sql.connector.expressions.filter.Predicate; + +import java.util.Objects; + +/** + * A CHECK constraint. + *
<p>
+ * A CHECK constraint defines a condition each row in a table must satisfy. Connectors can define + * such constraints either in SQL (Spark SQL dialect) or using a {@link Predicate predicate} if the + * condition can be expressed using a supported expression. A CHECK constraint can reference one or + * more columns. Such constraint is considered violated if its condition evaluates to {@code FALSE} + * (not {@code NULL}). The search condition must be deterministic and cannot contain subqueries and + * certain functions like aggregates. + * + * @since 4.1.0 + */ +public class Check extends BaseConstraint { + + private final String sql; + private final Predicate predicate; + + Check( + String name, + String sql, + Predicate predicate, + ConstraintState state) { + super(name, state); + + if (sql == null && predicate == null) { + throw new IllegalArgumentException("SQL and predicate can't be both null"); + } + + this.sql = sql; + this.predicate = predicate; + } + + /** + * Returns the SQL representation of the search condition (Spark SQL dialect). + */ + public String sql() { + return sql; + } + + /** + * Returns the search condition. + */ + public Predicate predicate() { + return predicate; + } + + @Override + protected String definition() { + return String.format("CHECK %s", sql != null ? sql : predicate); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + Check that = (Check) other; + return Objects.equals(name(), that.name()) && + Objects.equals(sql, that.sql) && + Objects.equals(predicate, that.predicate) && + Objects.equals(state(), that.state()); + } + + @Override + public int hashCode() { + return Objects.hash(name(), sql, predicate, state()); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java new file mode 100644 index 000000000000..3272634c72e4 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.filter.Predicate; + +/** + * A constraint that defines valid states of data in a table. + * + * @since 4.1.0 + */ +@Evolving +public interface Constraint { + /** + * Returns the name of this constraint. 
+ */ + String name(); + + /** + * Returns the definition of this constraint in DDL format. + */ + String toDDL(); + + /** + * Returns the state of this constraint. + */ + ConstraintState state(); + + /** + * Creates a CHECK constraint with a search condition defined in SQL (Spark SQL dialect). + * + * @param name the constraint name + * @param sql the SQL representation of the search condition (Spark SQL dialect) + * @param state the constraint state + * @return a CHECK constraint with the provided configuration + */ + static Check check(String name, String sql, ConstraintState state) { + return new Check(name, sql, null /* no predicate */, state); + } + + /** + * Creates a CHECK constraint with a search condition defined by a {@link Predicate predicate}. + * + * @param name the constraint name + * @param predicate the search condition + * @param state the constraint state + * @return a CHECK constraint with the provided configuration + */ + static Check check(String name, Predicate predicate, ConstraintState state) { + return new Check(name, null /* no SQL */, predicate, state); + } + + /** + * Creates a CHECK constraint with a search condition defined in SQL (Spark SQL dialect) and + * by {@link Predicate} (if the SQL representation can be converted into a supported expression). + * The SQL string and predicate must be equivalent. + * + * @param name the constraint name + * @param sql the SQL representation of the search condition (Spark SQL dialect) + * @param predicate the search condition + * @param state the constraint state + * @return a CHECK constraint with the provided configuration + */ + static Check check(String name, String sql, Predicate predicate, ConstraintState state) { + return new Check(name, sql, predicate, state); + } + + /** + * Creates a UNIQUE constraint. + * + * @param name the constraint name + * @param columns the columns that comprise the unique key + * @param state the constraint state + * @return a UNIQUE constraint with the provided configuration + */ + static Unique unique(String name, NamedReference[] columns, ConstraintState state) { + return new Unique(name, columns, state); + } + + /** + * Creates a PRIMARY KEY constraint. + * + * @param name the constraint name + * @param columns the columns that comprise the primary key + * @param state the constraint state + * @return a PRIMARY KEY constraint with the provided configuration + */ + static PrimaryKey primaryKey(String name, NamedReference[] columns, ConstraintState state) { + return new PrimaryKey(name, columns, state); + } + + /** + * Creates a FOREIGN KEY constraint. 
+ * + * @param name the constraint name + * @param columns the referencing columns + * @param refTable the referenced table identifier + * @param refColumns the referenced columns in the referenced table + * @param state the constraint state + * @return a FOREIGN KEY constraint with the provided configuration + */ + static ForeignKey foreignKey( + String name, + NamedReference[] columns, + Identifier refTable, + NamedReference[] refColumns, + ConstraintState state) { + return new ForeignKey(name, columns, refTable, refColumns, state); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ConstraintState.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ConstraintState.java new file mode 100644 index 000000000000..5c0b8b0f97a6 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ConstraintState.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import java.util.Objects; + +/** + * A constraint state with the following properties: + *

    + *
+ * <ul>
+ *   <li>Enforced: Indicates whether the constraint is actively enforced. If enforced,
+ *   data modification operations that violate the constraint fail with a constraint
+ *   violation error.</li>
+ *   <li>Validated: Indicates whether the existing data in the table satisfies the
+ *   constraint. A constraint may be validated independently from enforcement, meaning it can
+ *   be validated without being actively enforced, or vice versa.</li>
+ *   <li>Rely: Indicates whether the constraint is assumed to hold true even if it is not
+ *   validated. The reliance state allows query optimizers to utilize the constraint for
+ *   optimization purposes.</li>
+ * </ul>
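+ * <p>
+ * As an illustrative sketch only (relying solely on the builder methods defined below, with
+ * no additional API assumed), a state that is enforced and validated but that optimizers
+ * should not rely on could be built as:
+ * <pre>{@code
+ *   // enforced and validated, but not relied upon for optimization
+ *   ConstraintState state = ConstraintState.builder()
+ *     .enforced(true)
+ *     .validated(true)
+ *     .rely(false)
+ *     .build();
+ * }</pre>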
+ * + * @since 4.1.0 + */ +public class ConstraintState { + + private final boolean enforced; + private final boolean validated; + private final boolean rely; + + private ConstraintState(boolean enforced, boolean validated, boolean rely) { + this.enforced = enforced; + this.validated = validated; + this.rely = rely; + } + + /** + * Indicates whether the constraint is actively enforced. + */ + public boolean enforced() { + return enforced; + } + + /** + * Indicates whether the existing data is known to satisfy the constraint. + */ + public boolean validated() { + return validated; + } + + /** + * Indicates whether the constraint is assumed to hold true. + */ + public boolean rely() { + return rely; + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + ConstraintState that = (ConstraintState) other; + return enforced == that.enforced && validated == that.validated && rely == that.rely; + } + + @Override + public int hashCode() { + return Objects.hash(enforced, validated, rely); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private boolean enforced; + private boolean validated; + private boolean rely; + + private Builder() {} + + public Builder enforced(boolean value) { + this.enforced = value; + return this; + } + + public Builder validated(boolean value) { + this.validated = value; + return this; + } + + public Builder rely(boolean value) { + this.rely = value; + return this; + } + + public ConstraintState build() { + return new ConstraintState(enforced, validated, rely); + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java new file mode 100644 index 000000000000..7a1f9bd2ae21 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.expressions.NamedReference; + +import java.util.Arrays; +import java.util.Objects; + +/** + * A FOREIGN KEY constraint. + *

+ * A FOREIGN KEY constraint specifies one or more columns (referencing columns) in a table that + * refer to corresponding columns (referenced columns) in another table. The referenced columns + * must form a UNIQUE or PRIMARY KEY constraint in the referenced table. For this constraint to be + * satisfied, each row in the table must contain values in the referencing columns that exactly + * match values of a row in the referenced table. + * + * @since 4.1.0 + */ +public class ForeignKey extends BaseConstraint { + + private final NamedReference[] columns; + private final Identifier refTable; + private final NamedReference[] refColumns; + + ForeignKey( + String name, + NamedReference[] columns, + Identifier refTable, + NamedReference[] refColumns, + ConstraintState state) { + super(name, state); + this.columns = columns; + this.refTable = refTable; + this.refColumns = refColumns; + } + + /** + * Returns the referencing columns. + */ + public NamedReference[] columns() { + return columns; + } + + /** + * Returns the referenced table. + */ + public Identifier referencedTable() { + return refTable; + } + + /** + * Returns the referenced columns in the referenced table. + */ + public NamedReference[] referencedColumns() { + return refColumns; + } + + @Override + protected String definition() { + return String.format( + "FOREIGN KEY (%s) REFERENCES %s (%s)", + toSQL(columns), + refTable.toString(), + toSQL(refColumns)); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + ForeignKey that = (ForeignKey) other; + return Objects.equals(name(), that.name()) && + Arrays.equals(columns, that.columns) && + Objects.equals(refTable, that.refTable) && + Arrays.equals(refColumns, that.refColumns) && + Objects.equals(state(), that.state()); + } + + @Override + public int hashCode() { + int result = Objects.hash(name(), refTable, state()); + result = 31 * result + Arrays.hashCode(columns); + result = 31 * result + Arrays.hashCode(refColumns); + return result; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java new file mode 100644 index 000000000000..79be39abdc0e --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import org.apache.spark.sql.connector.expressions.NamedReference; + +import java.util.Arrays; +import java.util.Objects; + +/** + * A PRIMARY KEY constraint. + *

+ * A PRIMARY KEY constraint specifies ore or more columns as a primary key. Such constraint is + * satisfied if and only if no two rows in a table have the same non-null values in the primary + * key columns and none of the values in the specified column or columns are {@code NULL}. + * + * @since 4.1.0 + */ +public class PrimaryKey extends BaseConstraint { + + private final NamedReference[] columns; + + PrimaryKey( + String name, + NamedReference[] columns, + ConstraintState state) { + super(name, state); + this.columns = columns; + } + + /** + * Returns the columns that comprise the primary key. + */ + public NamedReference[] columns() { + return columns; + } + + @Override + protected String definition() { + return String.format("PRIMARY KEY (%s)", toSQL(columns)); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + PrimaryKey that = (PrimaryKey) other; + return Objects.equals(name(), that.name()) && + Arrays.equals(columns, that.columns()) && + Objects.equals(state(), that.state()); + } + + @Override + public int hashCode() { + int result = Objects.hash(name(), state()); + result = 31 * result + Arrays.hashCode(columns); + return result; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java new file mode 100644 index 000000000000..4a87f54341dc --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.constraints; + +import org.apache.spark.sql.connector.expressions.NamedReference; + +import java.util.Arrays; +import java.util.Objects; + +/** + * A UNIQUE constraint. + *

+ * A UNIQUE constraint specifies one or more columns as unique columns. Such constraint is satisfied + * if and only if no two rows in a table have the same non-null values in the unique columns. + * + * @since 4.1.0 + */ +public class Unique extends BaseConstraint { + + private final NamedReference[] columns; + + Unique( + String name, + NamedReference[] columns, + ConstraintState state) { + super(name, state); + this.columns = columns; + } + + /** + * Returns the columns that comprise the unique key. + */ + public NamedReference[] columns() { + return columns; + } + + @Override + protected String definition() { + return String.format("UNIQUE (%s)", toSQL(columns)); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + Unique that = (Unique) other; + return Objects.equals(name(), that.name()) && + Arrays.equals(columns, that.columns()) && + Objects.equals(state(), that.state()); + } + + @Override + public int hashCode() { + int result = Objects.hash(name(), state()); + result = 31 * result + Arrays.hashCode(columns); + return result; + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala new file mode 100644 index 000000000000..c0ec445aecb4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.catalog.constraints.{Constraint, ConstraintState} +import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue, NamedReference} +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.types.IntegerType + +class ConstraintSuite extends SparkFunSuite { + + test("CHECK constraint toDDL") { + val con1 = Constraint.check( + "con1", + "id > 10", + ConstraintState.builder().enforced(true).validated(true).rely(true).build()) + assert(con1.toDDL == "CONSTRAINT con1 CHECK id > 10 ENFORCED VALID RELY") + + val con2 = Constraint.check( + "con2", + new Predicate( + "=", + Array[Expression]( + FieldReference(Seq("a", "b.c", "d")), + LiteralValue(1, IntegerType))), + ConstraintState.builder().enforced(false).validated(true).rely(true).build()) + assert(con2.toDDL == "CONSTRAINT con2 CHECK a.`b.c`.d = 1 NOT ENFORCED VALID RELY") + + val con3 = Constraint.check( + "con3", + "a.b.c <=> 1", + new Predicate( + "<=>", + Array[Expression]( + FieldReference(Seq("a", "b", "c")), + LiteralValue(1, IntegerType))), + ConstraintState.builder().enforced(false).validated(false).rely(false).build()) + assert(con3.toDDL == "CONSTRAINT con3 CHECK a.b.c <=> 1 NOT ENFORCED NOT VALID NORELY") + } + + test("UNIQUE constraint toDDL") { + val con1 = Constraint.unique( + "con1", + Array[NamedReference](FieldReference(Seq("a", "b", "c")), FieldReference(Seq("d"))), + ConstraintState.builder().enforced(false).validated(false).rely(true).build()) + assert(con1.toDDL == "CONSTRAINT con1 UNIQUE (a.b.c, d) NOT ENFORCED NOT VALID RELY") + + val con2 = Constraint.unique( + "con2", + Array[NamedReference](FieldReference(Seq("a.b", "x", "y")), FieldReference(Seq("d"))), + ConstraintState.builder().enforced(false).validated(true).rely(true).build()) + assert(con2.toDDL == "CONSTRAINT con2 UNIQUE (`a.b`.x.y, d) NOT ENFORCED VALID RELY") + } + + test("PRIMARY KEY constraint toDDL") { + val pk1 = Constraint.primaryKey( + "pk1", + Array[NamedReference](FieldReference(Seq("a", "b", "c")), FieldReference(Seq("d"))), + ConstraintState.builder().enforced(true).validated(true).rely(true).build()) + assert(pk1.toDDL == "CONSTRAINT pk1 PRIMARY KEY (a.b.c, d) ENFORCED VALID RELY") + + val pk2 = Constraint.primaryKey( + "pk2", + Array[NamedReference](FieldReference(Seq("x.y", "z")), FieldReference(Seq("id"))), + ConstraintState.builder().enforced(false).validated(false).rely(false).build()) + assert(pk2.toDDL == "CONSTRAINT pk2 PRIMARY KEY (`x.y`.z, id) NOT ENFORCED NOT VALID NORELY") + } + + test("FOREIGN KEY constraint toDDL") { + val fk1 = Constraint.foreignKey( + "fk1", + Array[NamedReference](FieldReference(Seq("col1")), FieldReference(Seq("col2"))), + Identifier.of(Array("schema"), "table"), + Array[NamedReference](FieldReference(Seq("ref_col1")), FieldReference(Seq("ref_col2"))), + ConstraintState.builder().enforced(true).validated(true).rely(true).build()) + assert(fk1.toDDL == "CONSTRAINT fk1 FOREIGN KEY (col1, col2) " + + "REFERENCES schema.table (ref_col1, ref_col2) " + + "ENFORCED VALID RELY") + + val fk2 = Constraint.foreignKey( + "fk2", + Array[NamedReference](FieldReference(Seq("x.y", "z"))), + Identifier.of(Array.empty[String], "other_table"), + Array[NamedReference](FieldReference(Seq("other_id"))), + ConstraintState.builder().enforced(false).validated(false).rely(false).build()) + assert(fk2.toDDL == "CONSTRAINT fk2 FOREIGN KEY 
(`x.y`.z) " + + "REFERENCES other_table (other_id) " + + "NOT ENFORCED NOT VALID NORELY") + } +} From d029a9964e11581f0599bcb275daad2982730a10 Mon Sep 17 00:00:00 2001 From: Anton Okolnychyi Date: Thu, 20 Mar 2025 12:00:24 -0700 Subject: [PATCH 25/65] Flatten the structure, add builders --- .../catalog/constraints/BaseConstraint.java | 84 +++++++++++-- .../connector/catalog/constraints/Check.java | 65 +++++++--- .../catalog/constraints/Constraint.java | 95 +++++++------- .../catalog/constraints/ConstraintState.java | 116 ----------------- .../catalog/constraints/ForeignKey.java | 55 ++++++-- .../catalog/constraints/PrimaryKey.java | 37 +++++- .../connector/catalog/constraints/Unique.java | 37 ++++-- .../connector/catalog/ConstraintSuite.scala | 118 +++++++++++------- 8 files changed, 349 insertions(+), 258 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ConstraintState.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java index 2eb80a99d552..3d5bd5afe8aa 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java @@ -17,18 +17,26 @@ package org.apache.spark.sql.connector.catalog.constraints; -import org.apache.spark.sql.connector.expressions.NamedReference; - import java.util.StringJoiner; +import org.apache.spark.sql.connector.expressions.NamedReference; + abstract class BaseConstraint implements Constraint { private final String name; - private final ConstraintState state; + private final boolean enforced; + private final ValidationStatus validationStatus; + private final boolean rely; - protected BaseConstraint(String name, ConstraintState state) { + protected BaseConstraint( + String name, + boolean enforced, + ValidationStatus validationStatus, + boolean rely) { this.name = name; - this.state = state; + this.enforced = enforced; + this.validationStatus = validationStatus; + this.rely = rely; } protected abstract String definition(); @@ -39,8 +47,18 @@ public String name() { } @Override - public ConstraintState state() { - return state; + public boolean enforced() { + return enforced; + } + + @Override + public ValidationStatus validationStatus() { + return validationStatus; + } + + @Override + public boolean rely() { + return rely; } @Override @@ -49,9 +67,9 @@ public String toDDL() { "CONSTRAINT %s %s %s %s %s", name, definition(), - state.enforced() ? "ENFORCED" : "NOT ENFORCED", - state.validated() ? "VALID" : "NOT VALID", - state.rely() ? "RELY" : "NORELY"); + enforced ? "ENFORCED" : "NOT ENFORCED", + validationStatus, + rely ? 
"RELY" : "NORELY"); } @Override @@ -68,4 +86,50 @@ protected String toSQL(NamedReference[] columns) { return joiner.toString(); } + + abstract static class Builder { + private final String name; + private boolean enforced = true; + private ValidationStatus validationStatus = ValidationStatus.UNVALIDATED; + private boolean rely = false; + + Builder(String name) { + this.name = name; + } + + protected abstract B self(); + + public abstract C build(); + + public String name() { + return name; + } + + public B enforced(boolean enforced) { + this.enforced = enforced; + return self(); + } + + public boolean enforced() { + return enforced; + } + + public B validationStatus(ValidationStatus validationStatus) { + this.validationStatus = validationStatus; + return self(); + } + + public ValidationStatus validationStatus() { + return validationStatus; + } + + public B rely(boolean rely) { + this.rely = rely; + return self(); + } + + public boolean rely() { + return rely; + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java index 6f05f7450fc5..fbec7651d251 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java @@ -17,19 +17,21 @@ package org.apache.spark.sql.connector.catalog.constraints; -import org.apache.spark.sql.connector.expressions.filter.Predicate; - +import java.util.Map; import java.util.Objects; +import org.apache.spark.SparkIllegalArgumentException; +import org.apache.spark.sql.connector.expressions.filter.Predicate; + /** * A CHECK constraint. *

* A CHECK constraint defines a condition each row in a table must satisfy. Connectors can define * such constraints either in SQL (Spark SQL dialect) or using a {@link Predicate predicate} if the * condition can be expressed using a supported expression. A CHECK constraint can reference one or - * more columns. Such constraint is considered violated if its condition evaluates to {@code FALSE} - * (not {@code NULL}). The search condition must be deterministic and cannot contain subqueries and - * certain functions like aggregates. + * more columns. Such constraint is considered violated if its condition evaluates to {@code FALSE}, + * but not {@code NULL}. The search condition must be deterministic and cannot contain subqueries + * and certain functions like aggregates or UDFs. * * @since 4.1.0 */ @@ -38,17 +40,14 @@ public class Check extends BaseConstraint { private final String sql; private final Predicate predicate; - Check( + private Check( String name, String sql, Predicate predicate, - ConstraintState state) { - super(name, state); - - if (sql == null && predicate == null) { - throw new IllegalArgumentException("SQL and predicate can't be both null"); - } - + boolean enforced, + ValidationStatus validationStatus, + boolean rely) { + super(name, enforced, validationStatus, rely); this.sql = sql; this.predicate = predicate; } @@ -80,11 +79,47 @@ public boolean equals(Object other) { return Objects.equals(name(), that.name()) && Objects.equals(sql, that.sql) && Objects.equals(predicate, that.predicate) && - Objects.equals(state(), that.state()); + enforced() == that.enforced() && + Objects.equals(validationStatus(), that.validationStatus()) && + rely() == that.rely(); } @Override public int hashCode() { - return Objects.hash(name(), sql, predicate, state()); + return Objects.hash(name(), sql, predicate, enforced(), validationStatus(), rely()); + } + + public static class Builder extends BaseConstraint.Builder { + + private String sql; + private Predicate predicate; + + Builder(String name) { + super(name); + } + + @Override + protected Builder self() { + return this; + } + + public Builder sql(String sql) { + this.sql = sql; + return this; + } + + public Builder predicate(Predicate predicate) { + this.predicate = predicate; + return this; + } + + public Check build() { + if (sql == null && predicate == null) { + throw new SparkIllegalArgumentException( + "INTERNAL_ERROR", + Map.of("message", "SQL and predicate in CHECK can't be both null")); + } + return new Check(name(), sql, predicate, enforced(), validationStatus(), rely()); + } } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java index 3272634c72e4..f8381326a205 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java @@ -20,10 +20,9 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.catalog.Identifier; import org.apache.spark.sql.connector.expressions.NamedReference; -import org.apache.spark.sql.connector.expressions.filter.Predicate; /** - * A constraint that defines valid states of data in a table. + * A constraint that restricts states of data in a table. 
* * @since 4.1.0 */ @@ -35,94 +34,86 @@ public interface Constraint { String name(); /** - * Returns the definition of this constraint in DDL format. + * Indicates whether this constraint is actively enforced. If enforced, data modifications + * that violate the constraint fail with a constraint violation error. */ - String toDDL(); + boolean enforced(); /** - * Returns the state of this constraint. + * Indicates whether the existing data in the table satisfies this constraint. The constraint + * can be valid (the data is guaranteed to satisfy the constraint), invalid (some records violate + * the constraint), or unvalidated (the validity is unknown). */ - ConstraintState state(); + ValidationStatus validationStatus(); /** - * Creates a CHECK constraint with a search condition defined in SQL (Spark SQL dialect). - * - * @param name the constraint name - * @param sql the SQL representation of the search condition (Spark SQL dialect) - * @param state the constraint state - * @return a CHECK constraint with the provided configuration + * Indicates whether this constraint is assumed to hold true if the validity is unknown. */ - static Check check(String name, String sql, ConstraintState state) { - return new Check(name, sql, null /* no predicate */, state); - } + boolean rely(); /** - * Creates a CHECK constraint with a search condition defined by a {@link Predicate predicate}. - * - * @param name the constraint name - * @param predicate the search condition - * @param state the constraint state - * @return a CHECK constraint with the provided configuration + * Returns the definition of this constraint in the DDL format. */ - static Check check(String name, Predicate predicate, ConstraintState state) { - return new Check(name, null /* no SQL */, predicate, state); - } + String toDDL(); /** - * Creates a CHECK constraint with a search condition defined in SQL (Spark SQL dialect) and - * by {@link Predicate} (if the SQL representation can be converted into a supported expression). - * The SQL string and predicate must be equivalent. + * Instantiates a builder for a CHECK constraint. * * @param name the constraint name - * @param sql the SQL representation of the search condition (Spark SQL dialect) - * @param predicate the search condition - * @param state the constraint state - * @return a CHECK constraint with the provided configuration + * @return a CHECK constraint builder */ - static Check check(String name, String sql, Predicate predicate, ConstraintState state) { - return new Check(name, sql, predicate, state); + static Check.Builder check(String name) { + return new Check.Builder(name); } /** - * Creates a UNIQUE constraint. + * Instantiates a builder for a UNIQUE constraint. * * @param name the constraint name - * @param columns the columns that comprise the unique key - * @param state the constraint state - * @return a UNIQUE constraint with the provided configuration + * @param columns columns that comprise the unique key + * @return a UNIQUE constraint builder */ - static Unique unique(String name, NamedReference[] columns, ConstraintState state) { - return new Unique(name, columns, state); + static Unique.Builder unique(String name, NamedReference[] columns) { + return new Unique.Builder(name, columns); } /** - * Creates a PRIMARY KEY constraint. + * Instantiates a builder for a PRIMARY KEY constraint. 
* * @param name the constraint name - * @param columns the columns that comprise the primary key - * @param state the constraint state - * @return a PRIMARY KEY constraint with the provided configuration + * @param columns columns that comprise the primary key + * @return a PRIMARY KEY constraint builder */ - static PrimaryKey primaryKey(String name, NamedReference[] columns, ConstraintState state) { - return new PrimaryKey(name, columns, state); + static PrimaryKey.Builder primaryKey(String name, NamedReference[] columns) { + return new PrimaryKey.Builder(name, columns); } /** - * Creates a FOREIGN KEY constraint. + * Instantiates a builder for a FOREIGN KEY constraint. * * @param name the constraint name * @param columns the referencing columns * @param refTable the referenced table identifier * @param refColumns the referenced columns in the referenced table - * @param state the constraint state - * @return a FOREIGN KEY constraint with the provided configuration + * @return a FOREIGN KEY constraint builder */ - static ForeignKey foreignKey( + static ForeignKey.Builder foreignKey( String name, NamedReference[] columns, Identifier refTable, - NamedReference[] refColumns, - ConstraintState state) { - return new ForeignKey(name, columns, refTable, refColumns, state); + NamedReference[] refColumns) { + return new ForeignKey.Builder(name, columns, refTable, refColumns); + } + + /** + * An indicator of the validity of the constraint. + *

+ * A constraint may be validated independently of enforcement, meaning it can be validated + * without being actively enforced, or vice versa. A constraint can be valid (the data is + * guaranteed to satisfy the constraint), invalid (some records violate the constraint), + * or unvalidated (the validity is unknown). + */ + enum ValidationStatus { + VALID, INVALID, UNVALIDATED } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ConstraintState.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ConstraintState.java deleted file mode 100644 index 5c0b8b0f97a6..000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ConstraintState.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.catalog.constraints; - -import java.util.Objects; - -/** - * A constraint state with the following properties: - *

    - *
- * <ul>
- *   <li>Enforced: Indicates whether the constraint is actively enforced. If enforced,
- *   data modification operations that violate the constraint fail with a constraint
- *   violation error.</li>
- *   <li>Validated: Indicates whether the existing data in the table satisfies the
- *   constraint. A constraint may be validated independently from enforcement, meaning it can
- *   be validated without being actively enforced, or vice versa.</li>
- *   <li>Rely: Indicates whether the constraint is assumed to hold true even if it is not
- *   validated. The reliance state allows query optimizers to utilize the constraint for
- *   optimization purposes.</li>
- * </ul>
- * - * @since 4.1.0 - */ -public class ConstraintState { - - private final boolean enforced; - private final boolean validated; - private final boolean rely; - - private ConstraintState(boolean enforced, boolean validated, boolean rely) { - this.enforced = enforced; - this.validated = validated; - this.rely = rely; - } - - /** - * Indicates whether the constraint is actively enforced. - */ - public boolean enforced() { - return enforced; - } - - /** - * Indicates whether the existing data is known to satisfy the constraint. - */ - public boolean validated() { - return validated; - } - - /** - * Indicates whether the constraint is assumed to hold true. - */ - public boolean rely() { - return rely; - } - - @Override - public boolean equals(Object other) { - if (this == other) return true; - if (other == null || getClass() != other.getClass()) return false; - ConstraintState that = (ConstraintState) other; - return enforced == that.enforced && validated == that.validated && rely == that.rely; - } - - @Override - public int hashCode() { - return Objects.hash(enforced, validated, rely); - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private boolean enforced; - private boolean validated; - private boolean rely; - - private Builder() {} - - public Builder enforced(boolean value) { - this.enforced = value; - return this; - } - - public Builder validated(boolean value) { - this.validated = value; - return this; - } - - public Builder rely(boolean value) { - this.rely = value; - return this; - } - - public ConstraintState build() { - return new ConstraintState(enforced, validated, rely); - } - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java index 7a1f9bd2ae21..4763f95ba98b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java @@ -17,12 +17,12 @@ package org.apache.spark.sql.connector.catalog.constraints; -import org.apache.spark.sql.connector.catalog.Identifier; -import org.apache.spark.sql.connector.expressions.NamedReference; - import java.util.Arrays; import java.util.Objects; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.expressions.NamedReference; + /** * A FOREIGN KEY constraint. *

@@ -45,8 +45,10 @@ public class ForeignKey extends BaseConstraint { NamedReference[] columns, Identifier refTable, NamedReference[] refColumns, - ConstraintState state) { - super(name, state); + boolean enforced, + ValidationStatus validationStatus, + boolean rely) { + super(name, enforced, validationStatus, rely); this.columns = columns; this.refTable = refTable; this.refColumns = refColumns; @@ -78,7 +80,7 @@ protected String definition() { return String.format( "FOREIGN KEY (%s) REFERENCES %s (%s)", toSQL(columns), - refTable.toString(), + refTable, toSQL(refColumns)); } @@ -91,14 +93,51 @@ public boolean equals(Object other) { Arrays.equals(columns, that.columns) && Objects.equals(refTable, that.refTable) && Arrays.equals(refColumns, that.refColumns) && - Objects.equals(state(), that.state()); + enforced() == that.enforced() && + Objects.equals(validationStatus(), that.validationStatus()) && + rely() == that.rely(); } @Override public int hashCode() { - int result = Objects.hash(name(), refTable, state()); + int result = Objects.hash(name(), refTable, enforced(), validationStatus(), rely()); result = 31 * result + Arrays.hashCode(columns); result = 31 * result + Arrays.hashCode(refColumns); return result; } + + public static class Builder extends BaseConstraint.Builder { + + private final NamedReference[] columns; + private final Identifier refTable; + private final NamedReference[] refColumns; + + public Builder( + String name, + NamedReference[] columns, + Identifier refTable, + NamedReference[] refColumns) { + super(name); + this.columns = columns; + this.refTable = refTable; + this.refColumns = refColumns; + } + + @Override + protected Builder self() { + return this; + } + + @Override + public ForeignKey build() { + return new ForeignKey( + name(), + columns, + refTable, + refColumns, + enforced(), + validationStatus(), + rely()); + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java index 79be39abdc0e..caaf29c10538 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java @@ -17,17 +17,18 @@ package org.apache.spark.sql.connector.catalog.constraints; -import org.apache.spark.sql.connector.expressions.NamedReference; - import java.util.Arrays; import java.util.Objects; +import org.apache.spark.sql.connector.expressions.NamedReference; + /** * A PRIMARY KEY constraint. *

* A PRIMARY KEY constraint specifies ore or more columns as a primary key. Such constraint is * satisfied if and only if no two rows in a table have the same non-null values in the primary * key columns and none of the values in the specified column or columns are {@code NULL}. + * A table can have at most one primary key. * * @since 4.1.0 */ @@ -38,8 +39,10 @@ public class PrimaryKey extends BaseConstraint { PrimaryKey( String name, NamedReference[] columns, - ConstraintState state) { - super(name, state); + boolean enforced, + ValidationStatus validationStatus, + boolean rely) { + super(name, enforced, validationStatus, rely); this.columns = columns; } @@ -62,13 +65,35 @@ public boolean equals(Object other) { PrimaryKey that = (PrimaryKey) other; return Objects.equals(name(), that.name()) && Arrays.equals(columns, that.columns()) && - Objects.equals(state(), that.state()); + enforced() == that.enforced() && + Objects.equals(validationStatus(), that.validationStatus()) && + rely() == that.rely(); } @Override public int hashCode() { - int result = Objects.hash(name(), state()); + int result = Objects.hash(name(), enforced(), validationStatus(), rely()); result = 31 * result + Arrays.hashCode(columns); return result; } + + public static class Builder extends BaseConstraint.Builder { + + private final NamedReference[] columns; + + Builder(String name, NamedReference[] columns) { + super(name); + this.columns = columns; + } + + @Override + protected Builder self() { + return this; + } + + @Override + public PrimaryKey build() { + return new PrimaryKey(name(), columns, enforced(), validationStatus(), rely()); + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java index 4a87f54341dc..394ad6b814e6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java @@ -17,11 +17,11 @@ package org.apache.spark.sql.connector.catalog.constraints; -import org.apache.spark.sql.connector.expressions.NamedReference; - import java.util.Arrays; import java.util.Objects; +import org.apache.spark.sql.connector.expressions.NamedReference; + /** * A UNIQUE constraint. *

@@ -34,11 +34,13 @@ public class Unique extends BaseConstraint { private final NamedReference[] columns; - Unique( + private Unique( String name, NamedReference[] columns, - ConstraintState state) { - super(name, state); + boolean enforced, + ValidationStatus validationStatus, + boolean rely) { + super(name, enforced, validationStatus, rely); this.columns = columns; } @@ -61,13 +63,34 @@ public boolean equals(Object other) { Unique that = (Unique) other; return Objects.equals(name(), that.name()) && Arrays.equals(columns, that.columns()) && - Objects.equals(state(), that.state()); + enforced() == that.enforced() && + Objects.equals(validationStatus(), that.validationStatus()) && + rely() == that.rely(); } @Override public int hashCode() { - int result = Objects.hash(name(), state()); + int result = Objects.hash(name(), enforced(), validationStatus(), rely()); result = 31 * result + Arrays.hashCode(columns); return result; } + + public static class Builder extends BaseConstraint.Builder { + + private final NamedReference[] columns; + + Builder(String name, NamedReference[] columns) { + super(name); + this.columns = columns; + } + + @Override + protected Builder self() { + return this; + } + + public Unique build() { + return new Unique(name(), columns, enforced(), validationStatus(), rely()); + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala index c0ec445aecb4..6b4bea3b14cc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.connector.catalog import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.connector.catalog.constraints.{Constraint, ConstraintState} +import org.apache.spark.sql.connector.catalog.constraints.Constraint +import org.apache.spark.sql.connector.catalog.constraints.Constraint.ValidationStatus import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue, NamedReference} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.types.IntegerType @@ -26,81 +27,110 @@ import org.apache.spark.sql.types.IntegerType class ConstraintSuite extends SparkFunSuite { test("CHECK constraint toDDL") { - val con1 = Constraint.check( - "con1", - "id > 10", - ConstraintState.builder().enforced(true).validated(true).rely(true).build()) + val con1 = Constraint.check("con1") + .sql("id > 10") + .enforced(true) + .validationStatus(ValidationStatus.VALID) + .rely(true) + .build() assert(con1.toDDL == "CONSTRAINT con1 CHECK id > 10 ENFORCED VALID RELY") - val con2 = Constraint.check( - "con2", + val con2 = Constraint.check("con2") + .predicate( new Predicate( "=", Array[Expression]( FieldReference(Seq("a", "b.c", "d")), - LiteralValue(1, IntegerType))), - ConstraintState.builder().enforced(false).validated(true).rely(true).build()) + LiteralValue(1, IntegerType)))) + .enforced(false) + .validationStatus(ValidationStatus.VALID) + .rely(true) + .build() assert(con2.toDDL == "CONSTRAINT con2 CHECK a.`b.c`.d = 1 NOT ENFORCED VALID RELY") - val con3 = Constraint.check( - "con3", - "a.b.c <=> 1", - new Predicate( - "<=>", - Array[Expression]( - FieldReference(Seq("a", "b", "c")), - LiteralValue(1, IntegerType))), - 
ConstraintState.builder().enforced(false).validated(false).rely(false).build()) - assert(con3.toDDL == "CONSTRAINT con3 CHECK a.b.c <=> 1 NOT ENFORCED NOT VALID NORELY") + val con3 = Constraint.check("con3") + .sql("a.b.c <=> 1") + .predicate( + new Predicate( + "<=>", + Array[Expression]( + FieldReference(Seq("a", "b", "c")), + LiteralValue(1, IntegerType)))) + .enforced(false) + .validationStatus(ValidationStatus.INVALID) + .rely(false) + .build() + assert(con3.toDDL == "CONSTRAINT con3 CHECK a.b.c <=> 1 NOT ENFORCED INVALID NORELY") + + val con4 = Constraint.check("con4").sql("a = 1").build() + assert(con4.toDDL == "CONSTRAINT con4 CHECK a = 1 ENFORCED UNVALIDATED NORELY") } test("UNIQUE constraint toDDL") { val con1 = Constraint.unique( - "con1", - Array[NamedReference](FieldReference(Seq("a", "b", "c")), FieldReference(Seq("d"))), - ConstraintState.builder().enforced(false).validated(false).rely(true).build()) - assert(con1.toDDL == "CONSTRAINT con1 UNIQUE (a.b.c, d) NOT ENFORCED NOT VALID RELY") + "con1", + Array[NamedReference](FieldReference(Seq("a", "b", "c")), FieldReference(Seq("d")))) + .enforced(false) + .validationStatus(ValidationStatus.UNVALIDATED) + .rely(true) + .build() + assert(con1.toDDL == "CONSTRAINT con1 UNIQUE (a.b.c, d) NOT ENFORCED UNVALIDATED RELY") val con2 = Constraint.unique( - "con2", - Array[NamedReference](FieldReference(Seq("a.b", "x", "y")), FieldReference(Seq("d"))), - ConstraintState.builder().enforced(false).validated(true).rely(true).build()) + "con2", + Array[NamedReference](FieldReference(Seq("a.b", "x", "y")), FieldReference(Seq("d")))) + .enforced(false) + .validationStatus(ValidationStatus.VALID) + .rely(true) + .build() assert(con2.toDDL == "CONSTRAINT con2 UNIQUE (`a.b`.x.y, d) NOT ENFORCED VALID RELY") } test("PRIMARY KEY constraint toDDL") { val pk1 = Constraint.primaryKey( - "pk1", - Array[NamedReference](FieldReference(Seq("a", "b", "c")), FieldReference(Seq("d"))), - ConstraintState.builder().enforced(true).validated(true).rely(true).build()) + "pk1", + Array[NamedReference](FieldReference(Seq("a", "b", "c")), FieldReference(Seq("d")))) + .enforced(true) + .validationStatus(ValidationStatus.VALID) + .rely(true) + .build() assert(pk1.toDDL == "CONSTRAINT pk1 PRIMARY KEY (a.b.c, d) ENFORCED VALID RELY") val pk2 = Constraint.primaryKey( - "pk2", - Array[NamedReference](FieldReference(Seq("x.y", "z")), FieldReference(Seq("id"))), - ConstraintState.builder().enforced(false).validated(false).rely(false).build()) - assert(pk2.toDDL == "CONSTRAINT pk2 PRIMARY KEY (`x.y`.z, id) NOT ENFORCED NOT VALID NORELY") + "pk2", + Array[NamedReference](FieldReference(Seq("x.y", "z")), FieldReference(Seq("id")))) + .enforced(false) + .validationStatus(ValidationStatus.INVALID) + .rely(false) + .build() + assert(pk2.toDDL == "CONSTRAINT pk2 PRIMARY KEY (`x.y`.z, id) NOT ENFORCED INVALID NORELY") } test("FOREIGN KEY constraint toDDL") { val fk1 = Constraint.foreignKey( - "fk1", - Array[NamedReference](FieldReference(Seq("col1")), FieldReference(Seq("col2"))), - Identifier.of(Array("schema"), "table"), - Array[NamedReference](FieldReference(Seq("ref_col1")), FieldReference(Seq("ref_col2"))), - ConstraintState.builder().enforced(true).validated(true).rely(true).build()) + "fk1", + Array[NamedReference](FieldReference(Seq("col1")), FieldReference(Seq("col2"))), + Identifier.of(Array("schema"), "table"), + Array[NamedReference](FieldReference(Seq("ref_col1")), FieldReference(Seq("ref_col2")))) + .enforced(true) + .validationStatus(ValidationStatus.VALID) + 
.rely(true) + .build() assert(fk1.toDDL == "CONSTRAINT fk1 FOREIGN KEY (col1, col2) " + "REFERENCES schema.table (ref_col1, ref_col2) " + "ENFORCED VALID RELY") val fk2 = Constraint.foreignKey( - "fk2", - Array[NamedReference](FieldReference(Seq("x.y", "z"))), - Identifier.of(Array.empty[String], "other_table"), - Array[NamedReference](FieldReference(Seq("other_id"))), - ConstraintState.builder().enforced(false).validated(false).rely(false).build()) + "fk2", + Array[NamedReference](FieldReference(Seq("x.y", "z"))), + Identifier.of(Array.empty[String], "other_table"), + Array[NamedReference](FieldReference(Seq("other_id")))) + .enforced(false) + .validationStatus(ValidationStatus.INVALID) + .rely(false) + .build() assert(fk2.toDDL == "CONSTRAINT fk2 FOREIGN KEY (`x.y`.z) " + "REFERENCES other_table (other_id) " + - "NOT ENFORCED NOT VALID NORELY") + "NOT ENFORCED INVALID NORELY") } } From 12fe567b6d2f69478732c5ce4131577e95d20242 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 20 Mar 2025 14:54:23 -0700 Subject: [PATCH 26/65] fix compiling --- .../spark/sql/catalyst/expressions/constraints.scala | 9 ++++++++- .../sql/execution/command/v2/CheckConstraintSuite.scala | 2 +- .../command/v2/CreateTableConstraintSuite.scala | 2 +- .../sql/execution/command/v2/DropConstraintSuite.scala | 2 +- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index bc46d0ca4d85..8355b16e7c3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -38,7 +38,14 @@ case class CheckConstraint( def asConstraint: Constraint = { val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull - Constraint.check(name, sql, predicate) + Constraint + .check(name) + .sql(sql) + .predicate(predicate) + .rely(true) + .enforced(true) + .validationStatus(Constraint.ValidationStatus.VALID) + .build() } override protected def withNewChildInternal(newChild: Expression): Expression = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index 1dac4fd85380..d226bf5ca69a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.catalog.constraints.Constraint.Check +import org.apache.spark.sql.connector.catalog.constraints.Check import org.apache.spark.sql.execution.command.DDLCommandTestUtils class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala index 0a9cd948a9f4..bd78eb4d3c46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.connector.catalog.constraints.Constraint.Check +import org.apache.spark.sql.connector.catalog.constraints.Check import org.apache.spark.sql.execution.command.DDLCommandTestUtils class CreateTableConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala index 368bf71c53d9..f492e18a6e52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropConstraintSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.connector.catalog.constraints.Constraint.Check +import org.apache.spark.sql.connector.catalog.constraints.Check import org.apache.spark.sql.execution.command.DDLCommandTestUtils class DropConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { From 6c3086359c31b0c447151a43323cfa591ec98561 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 21 Mar 2025 17:32:41 -0700 Subject: [PATCH 27/65] refactor parser and fix tests --- .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 6 ++++ .../sql/catalyst/parser/SqlBaseParser.g4 | 34 +++++++++++++++++-- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 +-- .../catalyst/expressions/constraints.scala | 4 +++ .../sql/catalyst/parser/AstBuilder.scala | 34 +++++++++++++------ .../AlterTableAddConstraintParseSuite.scala | 8 ++--- .../command/v2/CheckConstraintSuite.scala | 8 +++-- 7 files changed, 76 insertions(+), 22 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index b868eea41b69..0975b4dc61f0 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -222,6 +222,7 @@ DROP: 'DROP'; ELSE: 'ELSE'; ELSEIF: 'ELSEIF'; END: 'END'; +ENFORCED: 'ENFORCED'; ESCAPE: 'ESCAPE'; ESCAPED: 'ESCAPED'; EVOLUTION: 'EVOLUTION'; @@ -290,6 +291,7 @@ ITEMS: 'ITEMS'; ITERATE: 'ITERATE'; JOIN: 'JOIN'; JSON: 'JSON'; +KEY: 'KEY'; KEYS: 'KEYS'; LANGUAGE: 'LANGUAGE'; LAST: 'LAST'; @@ -337,6 +339,8 @@ NOT: 'NOT'; NULL: 'NULL'; NULLS: 'NULLS'; NUMERIC: 'NUMERIC'; +NORELY: 'NORELY'; +NOVALIDATE: 'NOVALIDATE'; OF: 'OF'; OFFSET: 'OFFSET'; ON: 'ON'; @@ -376,6 +380,7 @@ RECURSIVE: 'RECURSIVE'; REDUCE: 'REDUCE'; REFERENCES: 'REFERENCES'; REFRESH: 'REFRESH'; +RELY: 'RELY'; RENAME: 'RENAME'; REPAIR: 'REPAIR'; REPEAT: 'REPEAT'; @@ -475,6 +480,7 @@ UPDATE: 'UPDATE'; USE: 'USE'; USER: 'USER'; USING: 'USING'; +VALIDATE: 'VALIDATE'; VALUE: 'VALUE'; VALUES: 'VALUES'; VARCHAR: 'VARCHAR'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 729f3bc7cd55..b04017d14126 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1522,11 +1522,36 @@ number ; constraintSpec - : CONSTRAINT constraintName=errorCapturingIdentifier constraintExpression + : constraintName? constraintExpression constraintCharacteristics? + ; + +constraintName + : CONSTRAINT name=errorCapturingIdentifier ; constraintExpression - : CHECK '(' booleanExpression ')' #checkConstraint + : checkConstraint + | uniqueConstraint + | foreignKeyConstraint + ; + +checkConstraint + : CHECK LEFT_PAREN (expr=booleanExpression) RIGHT_PAREN + ; + +uniqueConstraint + : UNIQUE identifierList #uniqueConstraintClause + | PRIMARY KEY identifierList #primaryKeyConstraintClause + ; + +foreignKeyConstraint + : FOREIGN KEY identifierList REFERENCES table=multipartIdentifier identifierList + ; + +constraintCharacteristics + : NOT? ENFORCED + | RELY + | NORELY ; alterColumnSpecList @@ -1686,6 +1711,7 @@ ansiNonReserved | DOUBLE | DROP | ELSEIF + | ENFORCED | ESCAPED | EVOLUTION | EXCHANGE @@ -1774,6 +1800,8 @@ ansiNonReserved | NANOSECONDS | NO | NONE + | NORELY + | NOVALIDATE | NULLS | NUMERIC | OF @@ -1805,6 +1833,7 @@ ansiNonReserved | RECOVER | REDUCE | REFRESH + | RELY | RENAME | REPAIR | REPEAT @@ -1888,6 +1917,7 @@ ansiNonReserved | UNTIL | UPDATE | USE + | VALIDATE | VALUE | VALUES | VARCHAR diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a9bbf25dd5c2..d48ea55a45a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1193,14 +1193,14 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString case addConstraint @ AddCheckConstraint(table: ResolvedTable, constraintExpr) => if (!constraintExpr.resolved) { - constraintExpr.failAnalysis( + constraintExpr.child.failAnalysis( errorClass = "INVALID_CHECK_CONSTRAINT.UNRESOLVED", messageParameters = Map.empty ) } if (!constraintExpr.deterministic) { - constraintExpr.failAnalysis( + constraintExpr.child.failAnalysis( errorClass = "INVALID_CHECK_CONSTRAINT.NONDETERMINISTIC", messageParameters = Map.empty ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index 8355b16e7c3b..28b616782cd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -28,6 +28,8 @@ trait ConstraintExpression extends Expression with Unevaluable { override def dataType: DataType = StringType def asConstraint: Constraint + + def withName(name: String): ConstraintExpression } case class CheckConstraint( @@ -50,6 +52,8 @@ case class CheckConstraint( override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild) + + override def withName(name: String): ConstraintExpression = copy(name = name) } /* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e0965036cc02..cd8d89ede1ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5241,17 +5241,29 @@ class AstBuilder extends DataTypeAstBuilder AlterTableCollation(table, visitCollationSpec(ctx.collationSpec())) } - override def visitConstraintSpec(ctx: ConstraintSpecContext): ConstraintExpression = { - ctx.constraintExpression() match { - case c: CheckConstraintContext => - CheckConstraint( - name = ctx.constraintName.getText, - sql = c.booleanExpression().getText, - child = expression(c.booleanExpression()) - ) - case other => - throw QueryParsingErrors.constraintNotSupportedError(ctx, other.getText) - } + override def visitConstraintSpec(ctx: ConstraintSpecContext): ConstraintExpression = + withOrigin(ctx) { + val name = visitConstraintName(ctx.constraintName()) + val expr = + visitConstraintExpression(ctx.constraintExpression()).asInstanceOf[ConstraintExpression] + if (name != null) { + expr.withName(name) + } else { + expr + } + } + + override def visitConstraintName(ctx: ConstraintNameContext): String = { + ctx.name.getText + } + + override def visitCheckConstraint(ctx: CheckConstraintContext): CheckConstraint = + withOrigin(ctx) { + CheckConstraint( + name = "", + sql = ctx.expr.getText, + child = expression(ctx.booleanExpression()) + ) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala index 23bed2a1b827..890c6cd89b2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -37,7 +37,7 @@ class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSes "ALTER TABLE ... 
ADD CONSTRAINT"), CheckConstraint( "c1", - "d > 0", + "d>0", GreaterThan(UnresolvedAttribute("d"), Literal(0)))) comparePlans(parsed, expected) } @@ -47,10 +47,10 @@ class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSes """ |ALTER TABLE a.b.c ADD CONSTRAINT c1-c3 CHECK (d > 0) |""".stripMargin - val msg = intercept[ParseException] { + val e = intercept[ParseException] { parsePlan(sql) - }.getMessage - assert(msg.contains("Syntax error at or near '-'.")) + } + checkError(e, "INVALID_IDENTIFIER", "42602", Map("ident" -> "c1-c3")) } test("Add invalid check constraint expression") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index d226bf5ca69a..a08e1ca9f5e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -91,7 +91,8 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma val table = loadTable(catalog, "ns", "tbl") val constraint = getCheckConstraint(table) assert(constraint.name() == "c1") - assert(constraint.toDDL == "CHECK (from_json(j,'a INT').a>1)") + assert(constraint.toDDL == + "CONSTRAINT c1 CHECK from_json(j,'a INT').a>1 ENFORCED VALID RELY") assert(constraint.sql() == "from_json(j,'a INT').a>1") assert(constraint.predicate() == null) } @@ -106,7 +107,7 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma val table = loadTable(catalog, "ns", "tbl") val constraint = getCheckConstraint(table) assert(constraint.name() == "c1") - assert(constraint.toDDL == "CHECK (id>0)") + assert(constraint.toDDL == "CONSTRAINT c1 CHECK id>0 ENFORCED VALID RELY") } } @@ -125,7 +126,8 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma exception = error, condition = "CONSTRAINT_ALREADY_EXISTS", sqlState = "42710", - parameters = Map("constraintName" -> "abc", "oldConstraint" -> "CHECK (id>0)") + parameters = Map("constraintName" -> "abc", + "oldConstraint" -> "CONSTRAINT abc CHECK id>0 ENFORCED VALID RELY") ) } } From cdf55b5a1a04bf0a50aa2fea48d534ae51134a0c Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 24 Mar 2025 14:48:19 -0700 Subject: [PATCH 28/65] new syntax; compiling version --- .../resources/error/error-conditions.json | 11 ++++ .../sql/catalyst/parser/SqlBaseParser.g4 | 25 +++++++-- .../spark/sql/errors/QueryParsingErrors.scala | 9 +++ .../catalyst/expressions/constraints.scala | 55 ++++++++++++++++--- .../sql/catalyst/parser/AstBuilder.scala | 48 ++++++++++++---- .../AlterTableAddConstraintParseSuite.scala | 7 ++- .../v1/CreateTableConstraintParseSuite.scala | 17 ++++-- 7 files changed, 138 insertions(+), 34 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index d55b89f742f4..bb24ed027c0c 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2331,6 +2331,11 @@ "message" : [ "It contains nondeterministic expression." ] + }, + "MISSING_NAME": { + "message": [ + "The check constraint must have a name." 
+ ] } }, "sqlState": "42621" @@ -2360,6 +2365,12 @@ }, "sqlState" : "22022" }, + "INVALID_CONSTRAINT_CHARACTERISTICS": { + "message": [ + "Constraint characteristics [] are duplicated or conflict with each other." + ], + "sqlState": "42613" + }, "INVALID_CORRUPT_RECORD_TYPE" : { "message" : [ "The column for corrupt records must have the nullable STRING type, but got ." diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index b04017d14126..2e307592c151 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1522,7 +1522,7 @@ number ; constraintSpec - : constraintName? constraintExpression constraintCharacteristics? + : constraintName? constraintExpression constraintCharacteristic* ; constraintName @@ -1532,6 +1532,7 @@ constraintName constraintExpression : checkConstraint | uniqueConstraint + | primaryKeyConstraint | foreignKeyConstraint ; @@ -1540,17 +1541,29 @@ checkConstraint ; uniqueConstraint - : UNIQUE identifierList #uniqueConstraintClause - | PRIMARY KEY identifierList #primaryKeyConstraintClause + : UNIQUE identifierList + ; + +primaryKeyConstraint + : PRIMARY KEY identifierList ; foreignKeyConstraint : FOREIGN KEY identifierList REFERENCES table=multipartIdentifier identifierList ; -constraintCharacteristics - : NOT? ENFORCED - | RELY +constraintCharacteristic + : enforcedCharacteristic + | relyCharacteristic + ; + +enforcedCharacteristic + : ENFORCED + | NOT ENFORCED + ; + +relyCharacteristic + : RELY | NORELY ; diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index f13899d6e40a..e41cdfaaf1e7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -792,6 +792,15 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { new ParseException(errorClass = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED", ctx) } + def invalidConstraintCharacteristics( + ctx: ParserRuleContext, + characteristics: String): Throwable = { + new ParseException( + errorClass = "INVALID_CONSTRAINT_CHARACTERISTICS", + messageParameters = Map("characteristics" -> characteristics), + ctx) + } + def constraintNotSupportedError(ctx: ParserRuleContext, constraint: String): Throwable = { new ParseException( errorClass = "UNSUPPORTED_FEATURE.CONSTRAINT_TYPE", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index 28b616782cd0..afcbc516243c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder import org.apache.spark.sql.connector.catalog.constraints.Constraint @@ -29,23 +30,45 @@ trait ConstraintExpression extends Expression with Unevaluable { def asConstraint: 
Constraint - def withName(name: String): ConstraintExpression + def withNameAndCharacteristic( + name: String, + c: ConstraintCharacteristic): ConstraintExpression + + def name: String + + def characteristic: ConstraintCharacteristic + + def defaultName: String + + def defaultConstraintCharacteristic: ConstraintCharacteristic +} + +case class ConstraintCharacteristic(enforced: Option[Boolean], rely: Option[Boolean]) + +object ConstraintCharacteristic { + val empty: ConstraintCharacteristic = ConstraintCharacteristic(None, None) } case class CheckConstraint( - name: String, - override val sql: String, - child: Expression) extends ConstraintExpression + child: Expression, + condition: String, + override val name: String = null, + override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) + extends ConstraintExpression with UnaryLike[Expression] { def asConstraint: Constraint = { val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull + val rely = characteristic.rely.getOrElse(defaultConstraintCharacteristic.rely.get) + val enforced = + characteristic.enforced.getOrElse(defaultConstraintCharacteristic.enforced.get) + val constraintName = if (name == null) defaultName else name Constraint - .check(name) - .sql(sql) + .check(constraintName) + .sql(condition) .predicate(predicate) - .rely(true) - .enforced(true) + .rely(rely) + .enforced(enforced) .validationStatus(Constraint.ValidationStatus.VALID) .build() } @@ -53,7 +76,21 @@ case class CheckConstraint( override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild) - override def withName(name: String): ConstraintExpression = copy(name = name) + override def withNameAndCharacteristic( + name: String, + c: ConstraintCharacteristic): ConstraintExpression = { + copy(name = name, characteristic = c) + } + + override def defaultName: String = + throw new AnalysisException( + errorClass = "INVALID_CHECK_CONSTRAINT.MISSING_NAME", + messageParameters = Map.empty) + + override def defaultConstraintCharacteristic: ConstraintCharacteristic = + ConstraintCharacteristic(enforced = Some(true), rely = Some(true)) + + override def sql: String = s"CONSTRAINT $name CHECK ($condition)" } /* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cd8d89ede1ee..f6371cbcf279 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5243,14 +5243,16 @@ class AstBuilder extends DataTypeAstBuilder override def visitConstraintSpec(ctx: ConstraintSpecContext): ConstraintExpression = withOrigin(ctx) { - val name = visitConstraintName(ctx.constraintName()) - val expr = - visitConstraintExpression(ctx.constraintExpression()).asInstanceOf[ConstraintExpression] - if (name != null) { - expr.withName(name) + val name = if (ctx.constraintName() != null) { + visitConstraintName(ctx.constraintName()) } else { - expr + null } + val constraintCharacteristic = visitConstraintCharacteristic(ctx) + val expr = + visitConstraintExpression(ctx.constraintExpression()).asInstanceOf[ConstraintExpression] + + expr.withNameAndCharacteristic(name, constraintCharacteristic) } override def visitConstraintName(ctx: ConstraintNameContext): String = { @@ -5259,11 +5261,35 @@ class AstBuilder extends DataTypeAstBuilder override def visitCheckConstraint(ctx: 
CheckConstraintContext): CheckConstraint = withOrigin(ctx) { - CheckConstraint( - name = "", - sql = ctx.expr.getText, - child = expression(ctx.booleanExpression()) - ) + CheckConstraint( + child = expression(ctx.booleanExpression()), + condition = ctx.expr.getText) + } + + private def visitConstraintCharacteristic( + ctx: ConstraintSpecContext): ConstraintCharacteristic = { + var enforcement: Option[String] = None + var rely: Option[String] = None + ctx.constraintCharacteristic().asScala.foreach { + case e if e.enforcedCharacteristic() != null => + val text = e.enforcedCharacteristic().getText.toUpperCase(Locale.ROOT) + if (enforcement.isDefined) { + val invalidCharacteristics = s"${enforcement.get}, $text" + throw QueryParsingErrors.invalidConstraintCharacteristics(ctx, invalidCharacteristics) + } else { + enforcement = Some(text) + } + + case r if r.relyCharacteristic() != null => + val text = r.relyCharacteristic().getText.toUpperCase(Locale.ROOT) + if (rely.isDefined) { + val invalidCharacteristics = s"${rely.get}, $text" + throw QueryParsingErrors.invalidConstraintCharacteristics(ctx, invalidCharacteristics) + } else { + rely = Some(text) + } + } + ConstraintCharacteristic(enforcement.map(_ == "ENFORCED"), rely.map(_ == "RELY")) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala index 890c6cd89b2d..601d01f10523 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -36,9 +36,10 @@ class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSes Seq("a", "b", "c"), "ALTER TABLE ... 
ADD CONSTRAINT"), CheckConstraint( - "c1", - "d>0", - GreaterThan(UnresolvedAttribute("d"), Literal(0)))) + child = GreaterThan(UnresolvedAttribute("d"), Literal(0)), + condition = "d>0", + name = "c1" + )) comparePlans(parsed, expected) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala index 8e681b23abe4..6f7e81f21d43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala @@ -49,17 +49,24 @@ class CreateTableConstraintParseSuite extends AnalysisTest with SharedSparkSessi test("Create table with one check constraint") { val constraintStr = "CONSTRAINT c1 CHECK (a > 0)" - val constraint = CheckConstraint("c1", "a>0", GreaterThan(UnresolvedAttribute("a"), Literal(0))) + val constraint = CheckConstraint( + child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), + condition = "a>0", + name = "c1") val constraints = Constraints(Seq(constraint)) verifyConstraints(constraintStr, constraints) } test("Create table with two check constraints") { val constraintStr = "CONSTRAINT c1 CHECK (a > 0) CONSTRAINT c2 CHECK (b = 'foo')" - val constraint1 = - CheckConstraint("c1", "a>0", GreaterThan(UnresolvedAttribute("a"), Literal(0))) - val constraint2 = - CheckConstraint("c2", "b='foo'", EqualTo(UnresolvedAttribute("b"), Literal("foo"))) + val constraint1 = CheckConstraint( + child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), + condition = "a>0", + name = "c1") + val constraint2 = CheckConstraint( + child = EqualTo(UnresolvedAttribute("b"), Literal("foo")), + condition = "b='foo", + name = "c2") val constraints = Constraints(Seq(constraint1, constraint2)) verifyConstraints(constraintStr, constraints) } From c80e898ae9255fca581add84c48b5453f3945acf Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 24 Mar 2025 16:39:47 -0700 Subject: [PATCH 29/65] getOriginalText --- .../sql/catalyst/parser/AstBuilder.scala | 26 ++++++++++++++---- .../AlterTableAddConstraintParseSuite.scala | 27 +++++++++++++++++-- .../v1/CreateTableConstraintParseSuite.scala | 6 ++--- .../command/v2/CheckConstraintSuite.scala | 8 +++--- 4 files changed, 53 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f6371cbcf279..b14d02673728 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -117,6 +117,23 @@ class AstBuilder extends DataTypeAstBuilder } } + /** + * Retrieves the original input text for a given parser context, preserving all whitespace and + * formatting. + * + * ANTLR's default getText method removes whitespace because lexer rules typically skip it. + * This utility method extracts the exact text from the original input stream, using token + * indices. + * + * @param ctx The parser context to retrieve original text from. + * @return The original input text, including all whitespaces and formatting. 
+ */ + private def getOriginalText(ctx: ParserRuleContext): String = { + ctx.getStart.getInputStream.getText( + new Interval(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex) + ) + } + /** * Override the default behavior for all visit methods. This will only return a non-null result * when the context has only one child. This is done because there is no generic method to @@ -3946,9 +3963,7 @@ class AstBuilder extends DataTypeAstBuilder // use `Expression.sql` to avoid storing incorrect text caused by bugs in any expression's // `sql` method. Note: `exprCtx.getText` returns a string without spaces, so we need to // get the text from the underlying char stream instead. - val start = exprCtx.getStart.getStartIndex - val end = exprCtx.getStop.getStopIndex - val originalSQL = exprCtx.getStart.getInputStream.getText(new Interval(start, end)) + val originalSQL = getOriginalText(exprCtx) DefaultValueExpression(expr, originalSQL) } @@ -5261,9 +5276,10 @@ class AstBuilder extends DataTypeAstBuilder override def visitCheckConstraint(ctx: CheckConstraintContext): CheckConstraint = withOrigin(ctx) { + val condition = getOriginalText(ctx.expr) CheckConstraint( child = expression(ctx.booleanExpression()), - condition = ctx.expr.getText) + condition = condition) } private def visitConstraintCharacteristic( @@ -5272,7 +5288,7 @@ class AstBuilder extends DataTypeAstBuilder var rely: Option[String] = None ctx.constraintCharacteristic().asScala.foreach { case e if e.enforcedCharacteristic() != null => - val text = e.enforcedCharacteristic().getText.toUpperCase(Locale.ROOT) + val text = getOriginalText(e.enforcedCharacteristic()).toUpperCase(Locale.ROOT) if (enforcement.isDefined) { val invalidCharacteristics = s"${enforcement.get}, $text" throw QueryParsingErrors.invalidConstraintCharacteristics(ctx, invalidCharacteristics) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala index 601d01f10523..28145c230fa5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedTable} -import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, GreaterThan, Literal} +import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, ConstraintCharacteristic, GreaterThan, Literal} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.AddCheckConstraint @@ -37,7 +37,7 @@ class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSes "ALTER TABLE ... 
ADD CONSTRAINT"), CheckConstraint( child = GreaterThan(UnresolvedAttribute("d"), Literal(0)), - condition = "d>0", + condition = "d > 0", name = "c1" )) comparePlans(parsed, expected) @@ -64,4 +64,27 @@ class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSes }.getMessage assert(msg.contains("Syntax error at or near ')'")) } + + test("Add valid constraint characteristic") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) NOT ENFORCED + |""".stripMargin + val parsed = parsePlan(sql) + val characteristic = ConstraintCharacteristic( + enforced = Some(false), + rely = None + ) + val expected = AddCheckConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... ADD CONSTRAINT"), + CheckConstraint( + child = GreaterThan(UnresolvedAttribute("d"), Literal(0)), + condition = "d > 0", + name = "c1", + characteristic = characteristic + )) + comparePlans(parsed, expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala index 6f7e81f21d43..6349651f443f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala @@ -51,7 +51,7 @@ class CreateTableConstraintParseSuite extends AnalysisTest with SharedSparkSessi val constraintStr = "CONSTRAINT c1 CHECK (a > 0)" val constraint = CheckConstraint( child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), - condition = "a>0", + condition = "a > 0", name = "c1") val constraints = Constraints(Seq(constraint)) verifyConstraints(constraintStr, constraints) @@ -61,11 +61,11 @@ class CreateTableConstraintParseSuite extends AnalysisTest with SharedSparkSessi val constraintStr = "CONSTRAINT c1 CHECK (a > 0) CONSTRAINT c2 CHECK (b = 'foo')" val constraint1 = CheckConstraint( child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), - condition = "a>0", + condition = "a > 0", name = "c1") val constraint2 = CheckConstraint( child = EqualTo(UnresolvedAttribute("b"), Literal("foo")), - condition = "b='foo", + condition = "b = 'foo'", name = "c2") val constraints = Constraints(Seq(constraint1, constraint2)) verifyConstraints(constraintStr, constraints) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index a08e1ca9f5e3..900779f2d1c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -92,8 +92,8 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma val constraint = getCheckConstraint(table) assert(constraint.name() == "c1") assert(constraint.toDDL == - "CONSTRAINT c1 CHECK from_json(j,'a INT').a>1 ENFORCED VALID RELY") - assert(constraint.sql() == "from_json(j,'a INT').a>1") + "CONSTRAINT c1 CHECK from_json(j, 'a INT').a > 1 ENFORCED VALID RELY") + assert(constraint.sql() == "from_json(j, 'a INT').a > 1") assert(constraint.predicate() == null) } } @@ -107,7 +107,7 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma val table = loadTable(catalog, "ns", "tbl") val constraint = getCheckConstraint(table) 
assert(constraint.name() == "c1") - assert(constraint.toDDL == "CONSTRAINT c1 CHECK id>0 ENFORCED VALID RELY") + assert(constraint.toDDL == "CONSTRAINT c1 CHECK id > 0 ENFORCED VALID RELY") } } @@ -127,7 +127,7 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma condition = "CONSTRAINT_ALREADY_EXISTS", sqlState = "42710", parameters = Map("constraintName" -> "abc", - "oldConstraint" -> "CONSTRAINT abc CHECK id>0 ENFORCED VALID RELY") + "oldConstraint" -> "CONSTRAINT abc CHECK id > 0 ENFORCED VALID RELY") ) } } From 41dcc23f1ab6c7cc711cd8cb33531a6846646c49 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 24 Mar 2025 18:01:04 -0700 Subject: [PATCH 30/65] add more tests for characteristic --- .../sql/catalyst/parser/AstBuilder.scala | 6 +- .../AlterTableAddConstraintParseSuite.scala | 84 ++++++++++++++----- 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b14d02673728..9c85a1cc6e32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5291,7 +5291,8 @@ class AstBuilder extends DataTypeAstBuilder val text = getOriginalText(e.enforcedCharacteristic()).toUpperCase(Locale.ROOT) if (enforcement.isDefined) { val invalidCharacteristics = s"${enforcement.get}, $text" - throw QueryParsingErrors.invalidConstraintCharacteristics(ctx, invalidCharacteristics) + throw QueryParsingErrors.invalidConstraintCharacteristics( + e.enforcedCharacteristic(), invalidCharacteristics) } else { enforcement = Some(text) } @@ -5300,7 +5301,8 @@ class AstBuilder extends DataTypeAstBuilder val text = r.relyCharacteristic().getText.toUpperCase(Locale.ROOT) if (rely.isDefined) { val invalidCharacteristics = s"${rely.get}, $text" - throw QueryParsingErrors.invalidConstraintCharacteristics(ctx, invalidCharacteristics) + throw QueryParsingErrors.invalidConstraintCharacteristics( + r.relyCharacteristic(), invalidCharacteristics) } else { rely = Some(text) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala index 28145c230fa5..a4a2b160f5f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -65,26 +65,70 @@ class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSes assert(msg.contains("Syntax error at or near ')'")) } - test("Add valid constraint characteristic") { - val sql = - """ - |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) NOT ENFORCED - |""".stripMargin - val parsed = parsePlan(sql) - val characteristic = ConstraintCharacteristic( - enforced = Some(false), - rely = None + test("Add check constraint with valid characteristic") { + val combinations = Seq( + ("", "", ConstraintCharacteristic(enforced = None, rely = None)), + ("ENFORCED", "", ConstraintCharacteristic(enforced = Some(true), rely = None)), + ("NOT ENFORCED", "", ConstraintCharacteristic(enforced = Some(false), rely = None)), + ("", "RELY", ConstraintCharacteristic(enforced = None, rely = Some(true))), + ("", "NORELY", 
ConstraintCharacteristic(enforced = None, rely = Some(false))), + ("ENFORCED", "RELY", ConstraintCharacteristic(enforced = Some(true), rely = Some(true))), + ("ENFORCED", "NORELY", ConstraintCharacteristic(enforced = Some(true), rely = Some(false))), + ("NOT ENFORCED", "RELY", ConstraintCharacteristic(enforced = Some(false), rely = Some(true))), + ("NOT ENFORCED", "NORELY", + ConstraintCharacteristic(enforced = Some(false), rely = Some(false))) ) - val expected = AddCheckConstraint( - UnresolvedTable( - Seq("a", "b", "c"), - "ALTER TABLE ... ADD CONSTRAINT"), - CheckConstraint( - child = GreaterThan(UnresolvedAttribute("d"), Literal(0)), - condition = "d > 0", - name = "c1", - characteristic = characteristic - )) - comparePlans(parsed, expected) + + combinations.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = + s""" + |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) $enforcedStr $relyStr + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddCheckConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... ADD CONSTRAINT"), + CheckConstraint( + child = GreaterThan(UnresolvedAttribute("d"), Literal(0)), + condition = "d > 0", + name = "c1", + characteristic = characteristic + )) + comparePlans(parsed, expected) + } + } + + + test("Add check constraint with invalid characteristic") { + val combinations = Seq( + ("ENFORCED", "ENFORCED"), + ("ENFORCED", "NOT ENFORCED"), + ("NOT ENFORCED", "ENFORCED"), + ("NOT ENFORCED", "NOT ENFORCED"), + ("RELY", "RELY"), + ("RELY", "NORELY"), + ("NORELY", "RELY"), + ("NORELY", "NORELY") + ) + + combinations.foreach { case (characteristic1, characteristic2) => + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) $characteristic1 $characteristic2" + + val e = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT c1 CHECK (d > 0) $characteristic1 $characteristic2", + start = 22, + stop = 50 + characteristic1.length + characteristic2.length + ) + checkError( + exception = e, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } } } From 8bc6c329e0695c9918d76a8722a1472112cbdd05 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 24 Mar 2025 22:40:22 -0700 Subject: [PATCH 31/65] add tests in CreateTableConstraintParseSuite --- .../AlterTableAddConstraintParseSuite.scala | 22 ++------ .../command/ConstraintParseSuiteBase.scala | 50 +++++++++++++++++++ .../CreateTableConstraintParseSuite.scala | 39 +++++++++++++-- 3 files changed, 89 insertions(+), 22 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/command/{v1 => }/CreateTableConstraintParseSuite.scala (63%) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala index a4a2b160f5f9..ce6e86842346 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -16,14 +16,13 @@ */ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, 
UnresolvedTable} -import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, ConstraintCharacteristic, GreaterThan, Literal} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedTable} +import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, GreaterThan, Literal} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.AddCheckConstraint -import org.apache.spark.sql.test.SharedSparkSession -class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSession { +class AlterTableAddConstraintParseSuite extends ConstraintParseSuiteBase { test("Add check constraint") { val sql = @@ -66,20 +65,7 @@ class AlterTableAddConstraintParseSuite extends AnalysisTest with SharedSparkSes } test("Add check constraint with valid characteristic") { - val combinations = Seq( - ("", "", ConstraintCharacteristic(enforced = None, rely = None)), - ("ENFORCED", "", ConstraintCharacteristic(enforced = Some(true), rely = None)), - ("NOT ENFORCED", "", ConstraintCharacteristic(enforced = Some(false), rely = None)), - ("", "RELY", ConstraintCharacteristic(enforced = None, rely = Some(true))), - ("", "NORELY", ConstraintCharacteristic(enforced = None, rely = Some(false))), - ("ENFORCED", "RELY", ConstraintCharacteristic(enforced = Some(true), rely = Some(true))), - ("ENFORCED", "NORELY", ConstraintCharacteristic(enforced = Some(true), rely = Some(false))), - ("NOT ENFORCED", "RELY", ConstraintCharacteristic(enforced = Some(false), rely = Some(true))), - ("NOT ENFORCED", "NORELY", - ConstraintCharacteristic(enforced = Some(false), rely = Some(false))) - ) - - combinations.foreach { case (enforcedStr, relyStr, characteristic) => + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => val sql = s""" |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) $enforcedStr $relyStr diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala new file mode 100644 index 000000000000..c870399453af --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.expressions.ConstraintCharacteristic +import org.apache.spark.sql.test.SharedSparkSession + +abstract class ConstraintParseSuiteBase extends AnalysisTest with SharedSparkSession { + protected val validConstraintCharacteristics = Seq( + ("", "", ConstraintCharacteristic(enforced = None, rely = None)), + ("ENFORCED", "", ConstraintCharacteristic(enforced = Some(true), rely = None)), + ("NOT ENFORCED", "", ConstraintCharacteristic(enforced = Some(false), rely = None)), + ("", "RELY", ConstraintCharacteristic(enforced = None, rely = Some(true))), + ("", "NORELY", ConstraintCharacteristic(enforced = None, rely = Some(false))), + ("ENFORCED", "RELY", ConstraintCharacteristic(enforced = Some(true), rely = Some(true))), + ("ENFORCED", "NORELY", ConstraintCharacteristic(enforced = Some(true), rely = Some(false))), + ("NOT ENFORCED", "RELY", + ConstraintCharacteristic(enforced = Some(false), rely = Some(true))), + ("NOT ENFORCED", "NORELY", + ConstraintCharacteristic(enforced = Some(false), rely = Some(false))) + ) + + protected val invalidConstraintCharacteristics = Seq( + ("ENFORCED", "ENFORCED"), + ("ENFORCED", "NOT ENFORCED"), + ("NOT ENFORCED", "ENFORCED"), + ("NOT ENFORCED", "NOT ENFORCED"), + ("RELY", "RELY"), + ("RELY", "NORELY"), + ("NORELY", "RELY"), + ("NORELY", "NORELY") + ) + + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala similarity index 63% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala index 6349651f443f..0d8d98f1f20f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala @@ -15,16 +15,16 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.command.v1 +package org.apache.spark.sql.execution.command -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedIdentifier} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, Constraints, EqualTo, GreaterThan, Literal} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, OptionList, UnresolvedTableSpec} -import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType} -class CreateTableConstraintParseSuite extends AnalysisTest with SharedSparkSession { +class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { val createTablePrefix = "CREATE TABLE t (a INT, b STRING) USING parquet" val tableId = UnresolvedIdentifier(Seq("t")) val columns = Seq( @@ -70,4 +70,35 @@ class CreateTableConstraintParseSuite extends AnalysisTest with SharedSparkSessi val constraints = Constraints(Seq(constraint1, constraint2)) verifyConstraints(constraintStr, constraints) } + + test("Create table with valid characteristic") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $enforcedStr $relyStr" + val constraint = CheckConstraint( + child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), + condition = "a > 0", + name = "c1", + characteristic = characteristic) + val constraints = Constraints(Seq(constraint)) + verifyConstraints(constraintStr, constraints) + } + } + + test("Create table with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2" + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2", + start = 47, + stop = 75 + characteristic1.length + characteristic2.length + ) + checkError( + exception = intercept[ParseException] { + parsePlan(s"$createTablePrefix $constraintStr") + }, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } } From 24e3bf7c3af92e91a7bc768deb09c77b4cab4ac4 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 25 Mar 2025 10:31:21 -0700 Subject: [PATCH 32/65] add tests for ConstraintCharacteristics in CheckConstraintSuite --- .../command/v2/CheckConstraintSuite.scala | 49 ++++++++++++++----- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index 900779f2d1c5..329b6840d363 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -78,10 +78,7 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma assert(table.constraints.length == 1) assert(table.constraints.head.isInstanceOf[Check]) table.constraints.head.asInstanceOf[Check] - val constraint = 
table.constraints.head.asInstanceOf[Check] - assert(constraint.rely()) - assert(constraint.enforced()) - constraint + table.constraints.head.asInstanceOf[Check] } test("Predicate should be null if it can't be converted to V2 predicate") { @@ -98,16 +95,42 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma } } - test("Add check constraint") { - withNamespaceAndTable("ns", "tbl", catalog) { t => - sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") - assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + val validConstraintCharacteristics = Seq( + ("", "ENFORCED VALID RELY"), + ("NOT ENFORCED", "NOT ENFORCED VALID RELY"), + ("NOT ENFORCED NORELY", "NOT ENFORCED VALID NORELY"), + ("NORELY NOT ENFORCED", "NOT ENFORCED VALID NORELY"), + ("NORELY", "ENFORCED VALID NORELY"), + ("NOT ENFORCED RELY", "NOT ENFORCED VALID RELY"), + ("RELY NOT ENFORCED", "NOT ENFORCED VALID RELY"), + ("RELY", "ENFORCED VALID RELY") + ) - sql(s"ALTER TABLE $t ADD CONSTRAINT c1 CHECK (id > 0)") - val table = loadTable(catalog, "ns", "tbl") - val constraint = getCheckConstraint(table) - assert(constraint.name() == "c1") - assert(constraint.toDDL == "CONSTRAINT c1 CHECK id > 0 ENFORCED VALID RELY") + test("Create table with check constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + val constraintStr = s"CONSTRAINT c1 CHECK (id > 0) $characteristic" + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing $constraintStr") + val table = loadTable(catalog, "ns", "tbl") + val constraint = getCheckConstraint(table) + assert(constraint.name() == "c1") + assert(constraint.toDDL == s"CONSTRAINT c1 CHECK id > 0 $expectedDDL") + } + } + } + + test("Alter table add check constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT c1 CHECK (id > 0) $characteristic") + val table = loadTable(catalog, "ns", "tbl") + val constraint = getCheckConstraint(table) + assert(constraint.name() == "c1") + assert(constraint.toDDL == s"CONSTRAINT c1 CHECK id > 0 $expectedDDL") + } } } From 252332e87c8b9289c7d8bd92cf2037d9bec99257 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 25 Mar 2025 14:12:03 -0700 Subject: [PATCH 33/65] save for now --- .../sql/catalyst/parser/SqlBaseParser.g4 | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 49 ++++++++++++------- .../spark/sql/execution/SparkSqlParser.scala | 2 +- 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 2e307592c151..543f5674ddf4 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -562,7 +562,6 @@ createTableClauses locationSpec | commentSpec | collationSpec | - constraintSpec | (TBLPROPERTIES tableProps=propertyList))* ; @@ -1352,6 +1351,7 @@ colDefinitionOption | defaultExpression | generationExpression | commentSpec + | constraintSpec ; generationExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 9c85a1cc6e32..783a50c38248 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3861,23 +3861,26 @@ class AstBuilder extends DataTypeAstBuilder * Create top level table schema. */ protected def createSchema(ctx: ColDefinitionListContext): StructType = { - val columns = Option(ctx).toArray.flatMap(visitColDefinitionList) - StructType(columns.map(_.toV1Column)) + StructType(Option(ctx).toArray.flatMap{ c => + val (cols, _) = visitColDefinitionList(c) + cols.map(_.toV1Column) + }) } /** * Get CREATE TABLE column definitions. */ override def visitColDefinitionList( - ctx: ColDefinitionListContext): Seq[ColumnDefinition] = withOrigin(ctx) { - ctx.colDefinition().asScala.map(visitColDefinition).toSeq + ctx: ColDefinitionListContext): ColumnDefinitionList = withOrigin(ctx) { + val (colDefs, constraints) = ctx.colDefinition().asScala.map(visitColDefinition).toSeq.unzip + (colDefs, constraints.flatten) } /** * Get a CREATE TABLE column definition. */ override def visitColDefinition( - ctx: ColDefinitionContext): ColumnDefinition = withOrigin(ctx) { + ctx: ColDefinitionContext): ColumnAndConstraint = withOrigin(ctx) { import ctx._ val name: String = colName.getText @@ -3886,6 +3889,7 @@ class AstBuilder extends DataTypeAstBuilder var defaultExpression: Option[DefaultExpressionContext] = None var generationExpression: Option[GenerationExpressionContext] = None var commentSpec: Option[CommentSpecContext] = None + var constraintSpec: Option[ConstraintSpecContext] = None ctx.colDefinitionOption().asScala.foreach { option => if (option.NULL != null) { blockBang(option.errorCapturingNot) @@ -3919,10 +3923,16 @@ class AstBuilder extends DataTypeAstBuilder } commentSpec = Some(spec) } + Option(option.constraintSpec()).foreach { spec => + if (constraintSpec.isDefined) { + throw QueryParsingErrors.duplicateTableColumnDescriptor( + option, name, "CONSTRAINT") + } + } } val dataType = typedVisit[DataType](ctx.dataType) - ColumnDefinition( + val columnDef = ColumnDefinition( name = name, dataType = dataType, nullable = nullable, @@ -3935,6 +3945,8 @@ class AstBuilder extends DataTypeAstBuilder case ctx: IdentityColumnContext => visitIdentityColumn(ctx, dataType) } ) + val constraint = constraintSpec.map(visitConstraintSpec) + (columnDef, constraint) } /** @@ -4153,8 +4165,11 @@ class AstBuilder extends DataTypeAstBuilder */ type TableClauses = ( Seq[Transform], Seq[ColumnDefinition], Option[BucketSpec], Map[String, String], OptionList, - Option[String], Option[String], Option[String], Option[SerdeInfo], Option[ClusterBySpec], - Constraints) + Option[String], Option[String], Option[String], Option[SerdeInfo], Option[ClusterBySpec]) + + type ColumnAndConstraint = (ColumnDefinition, Option[ConstraintExpression]) + + type ColumnDefinitionList = (Seq[ColumnDefinition], Seq[ConstraintExpression]) /** * Validate a create table statement and return the [[TableIdentifier]]. 
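(Illustrative sketch of the syntax this change is meant to accept, not taken from the patch's tests:
with constraintSpec now allowed as a colDefinitionOption, a column-level check constraint such as

    CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0)) USING parquet

is expected to parse, and visitColDefinition returns the ColumnDefinition together with the
optional ConstraintExpression as a ColumnAndConstraint pair.)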
@@ -4657,10 +4672,8 @@ class AstBuilder extends DataTypeAstBuilder } } - val constraints = Constraints(ctx.constraintSpec().asScala.map(visitConstraintSpec).toSeq) - (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment, - collation, serdeInfo, clusterBySpec, constraints) + collation, serdeInfo, clusterBySpec) } protected def getSerdeInfo( @@ -4732,10 +4745,11 @@ class AstBuilder extends DataTypeAstBuilder val (identifierContext, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - val columns = Option(ctx.colDefinitionList()).map(visitColDefinitionList).getOrElse(Nil) + val (columns, colConstraints) = + Option(ctx.colDefinitionList()).map(visitColDefinitionList).getOrElse((Nil, Nil)) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) val (partTransforms, partCols, bucketSpec, properties, options, location, comment, - collation, serdeInfo, clusterBySpec, constraints) = + collation, serdeInfo, clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses()) if (provider.isDefined && serdeInfo.isDefined) { @@ -4754,7 +4768,7 @@ class AstBuilder extends DataTypeAstBuilder clusterBySpec.map(_.asTransform) val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, - collation, serdeInfo, external, constraints) + collation, serdeInfo, external, Constraints(colConstraints)) Option(ctx.query).map(plan) match { case Some(_) if columns.nonEmpty => @@ -4814,8 +4828,9 @@ class AstBuilder extends DataTypeAstBuilder override def visitReplaceTable(ctx: ReplaceTableContext): LogicalPlan = withOrigin(ctx) { val orCreate = ctx.replaceTableHeader().CREATE() != null val (partTransforms, partCols, bucketSpec, properties, options, location, comment, collation, - serdeInfo, clusterBySpec, constraints) = visitCreateTableClauses(ctx.createTableClauses()) - val columns = Option(ctx.colDefinitionList()).map(visitColDefinitionList).getOrElse(Nil) + serdeInfo, clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses()) + val (columns, colConstraints) = + Option(ctx.colDefinitionList()).map(visitColDefinitionList).getOrElse((Nil, Nil)) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) if (provider.isDefined && serdeInfo.isDefined) { @@ -4828,7 +4843,7 @@ class AstBuilder extends DataTypeAstBuilder clusterBySpec.map(_.asTransform) val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, - collation, serdeInfo, external = false, constraints) + collation, serdeInfo, external = false, Constraints(colConstraints)) Option(ctx.query).map(plan) match { case Some(_) if columns.nonEmpty => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index efb1602a2b69..8859b7b421b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -361,7 +361,7 @@ class SparkSqlAstBuilder extends AstBuilder { invalidStatement("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx) } - val (_, _, _, _, options, location, _, _, _, _, _) = + val (_, _, _, _, options, location, _, _, _, _) = visitCreateTableClauses(ctx.createTableClauses()) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText).getOrElse( throw QueryParsingErrors.createTempTableNotSpecifyProviderError(ctx)) From 462de00bb45adbce35b343740ae09aca4499ec2d Mon Sep 17 
00:00:00 2001 From: Gengliang Wang Date: Tue, 25 Mar 2025 15:52:43 -0700 Subject: [PATCH 34/65] fix create table constraint syntax --- .../sql/catalyst/parser/SqlBaseParser.g4 | 8 +++-- .../sql/catalyst/parser/AstBuilder.scala | 32 +++++++++++++++---- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 543f5674ddf4..876662f2ba6d 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -198,7 +198,7 @@ statement (RESTRICT | CASCADE)? #dropNamespace | SHOW namespaces ((FROM | IN) multipartIdentifier)? (LIKE? pattern=stringLit)? #showNamespaces - | createTableHeader (LEFT_PAREN colDefinitionList RIGHT_PAREN)? tableProvider? + | createTableHeader (LEFT_PAREN colDefinitionList constraintListWithLeadingComma? RIGHT_PAREN)? tableProvider? createTableClauses (AS? query)? #createTable | CREATE TABLE (IF errorCapturingNot EXISTS)? target=tableIdentifier @@ -208,7 +208,7 @@ statement createFileFormat | locationSpec | (TBLPROPERTIES tableProps=propertyList))* #createTableLike - | replaceTableHeader (LEFT_PAREN colDefinitionList RIGHT_PAREN)? tableProvider? + | replaceTableHeader (LEFT_PAREN colDefinitionList constraintListWithLeadingComma? RIGHT_PAREN)? tableProvider? createTableClauses (AS? query)? #replaceTable | ANALYZE TABLE identifierReference partitionSpec? COMPUTE STATISTICS @@ -1521,6 +1521,10 @@ number | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral ; +constraintListWithLeadingComma + : COMMA constraintSpec (COMMA constraintSpec)* + ; + constraintSpec : constraintName? constraintExpression constraintCharacteristic* ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 783a50c38248..61df9f96e137 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -4711,6 +4711,18 @@ class AstBuilder extends DataTypeAstBuilder } } + private def visitColumnDefinitionList( + colCtx: ColDefinitionListContext, + constraintListContext: ConstraintListWithLeadingCommaContext): ColumnDefinitionList = { + val (columns, colConstraints) = Option(colCtx).map(visitColDefinitionList).getOrElse((Nil, Nil)) + + val tableConstraints = Option(constraintListContext) + .map(visitConstraintListWithLeadingComma) + .getOrElse(Nil) + + (columns, colConstraints ++ tableConstraints) + } + /** * Create a table, returning a [[CreateTable]] or [[CreateTableAsSelect]] logical plan. 
* @@ -4745,8 +4757,9 @@ class AstBuilder extends DataTypeAstBuilder val (identifierContext, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - val (columns, colConstraints) = - Option(ctx.colDefinitionList()).map(visitColDefinitionList).getOrElse((Nil, Nil)) + val (columns, constraints) = + visitColumnDefinitionList(ctx.colDefinitionList(), ctx.constraintListWithLeadingComma()) + val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) val (partTransforms, partCols, bucketSpec, properties, options, location, comment, collation, serdeInfo, clusterBySpec) = @@ -4768,7 +4781,7 @@ class AstBuilder extends DataTypeAstBuilder clusterBySpec.map(_.asTransform) val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, - collation, serdeInfo, external, Constraints(colConstraints)) + collation, serdeInfo, external, Constraints(constraints)) Option(ctx.query).map(plan) match { case Some(_) if columns.nonEmpty => @@ -4829,8 +4842,8 @@ class AstBuilder extends DataTypeAstBuilder val orCreate = ctx.replaceTableHeader().CREATE() != null val (partTransforms, partCols, bucketSpec, properties, options, location, comment, collation, serdeInfo, clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses()) - val (columns, colConstraints) = - Option(ctx.colDefinitionList()).map(visitColDefinitionList).getOrElse((Nil, Nil)) + val (columns, constraints) = + visitColumnDefinitionList(ctx.colDefinitionList(), ctx.constraintListWithLeadingComma()) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) if (provider.isDefined && serdeInfo.isDefined) { @@ -4843,7 +4856,7 @@ class AstBuilder extends DataTypeAstBuilder clusterBySpec.map(_.asTransform) val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, - collation, serdeInfo, external = false, Constraints(colConstraints)) + collation, serdeInfo, external = false, Constraints(constraints)) Option(ctx.query).map(plan) match { case Some(_) if columns.nonEmpty => @@ -5325,8 +5338,13 @@ class AstBuilder extends DataTypeAstBuilder ConstraintCharacteristic(enforcement.map(_ == "ENFORCED"), rely.map(_ == "RELY")) } + override def visitConstraintListWithLeadingComma( + ctx: ConstraintListWithLeadingCommaContext): Seq[ConstraintExpression] = { + ctx.constraintSpec().asScala.map(visitConstraintSpec).toSeq + } + /** - * Parse a [[AddCheckConstraint]] command. + * Parse an [[AlterTableCommand]] with table constraint. * * For example: * {{{ From 60f2d0bdfd3411de7dca2ec02937bb71414f31f4 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 25 Mar 2025 21:36:57 -0700 Subject: [PATCH 35/65] refactor syntax --- .../sql/catalyst/parser/SqlBaseParser.g4 | 18 ++++--- .../sql/catalyst/parser/AstBuilder.scala | 47 +++++++++---------- .../spark/sql/execution/SparkSqlParser.scala | 2 +- .../CreateTableConstraintParseSuite.scala | 6 ++- 4 files changed, 39 insertions(+), 34 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 876662f2ba6d..02aca9105453 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -198,7 +198,7 @@ statement (RESTRICT | CASCADE)? #dropNamespace | SHOW namespaces ((FROM | IN) multipartIdentifier)? (LIKE? pattern=stringLit)? 
#showNamespaces - | createTableHeader (LEFT_PAREN colDefinitionList constraintListWithLeadingComma? RIGHT_PAREN)? tableProvider? + | createTableHeader (LEFT_PAREN tableElementList RIGHT_PAREN)? tableProvider? createTableClauses (AS? query)? #createTable | CREATE TABLE (IF errorCapturingNot EXISTS)? target=tableIdentifier @@ -208,7 +208,7 @@ statement createFileFormat | locationSpec | (TBLPROPERTIES tableProps=propertyList))* #createTableLike - | replaceTableHeader (LEFT_PAREN colDefinitionList constraintListWithLeadingComma? RIGHT_PAREN)? tableProvider? + | replaceTableHeader (LEFT_PAREN tableElementList RIGHT_PAREN)? tableProvider? createTableClauses (AS? query)? #replaceTable | ANALYZE TABLE identifierReference partitionSpec? COMPUTE STATISTICS @@ -1293,7 +1293,6 @@ type | INTERVAL | VARIANT | ARRAY | STRUCT | MAP - | unsupportedType=identifier ; dataType @@ -1338,6 +1337,15 @@ colType : colName=errorCapturingIdentifier dataType (errorCapturingNot NULL)? commentSpec? ; +tableElementList + : tableElement (COMMA tableElement)* + ; + +tableElement + : colDefinition + | constraintSpec + ; + colDefinitionList : colDefinition (COMMA colDefinition)* ; @@ -1521,10 +1529,6 @@ number | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral ; -constraintListWithLeadingComma - : COMMA constraintSpec (COMMA constraintSpec)* - ; - constraintSpec : constraintName? constraintExpression constraintCharacteristic* ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 61df9f96e137..2e7c0b75ddda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3860,18 +3860,16 @@ class AstBuilder extends DataTypeAstBuilder /** * Create top level table schema. */ - protected def createSchema(ctx: ColDefinitionListContext): StructType = { - StructType(Option(ctx).toArray.flatMap{ c => - val (cols, _) = visitColDefinitionList(c) - cols.map(_.toV1Column) - }) + protected def createSchema(ctx: TableElementListContext): StructType = { + val (cols, _) = visitTableElementList(ctx) + StructType(cols.map(_.toV1Column)) } /** * Get CREATE TABLE column definitions. */ override def visitColDefinitionList( - ctx: ColDefinitionListContext): ColumnDefinitionList = withOrigin(ctx) { + ctx: ColDefinitionListContext): TableElementList = withOrigin(ctx) { val (colDefs, constraints) = ctx.colDefinition().asScala.map(visitColDefinition).toSeq.unzip (colDefs, constraints.flatten) } @@ -4169,7 +4167,7 @@ class AstBuilder extends DataTypeAstBuilder type ColumnAndConstraint = (ColumnDefinition, Option[ConstraintExpression]) - type ColumnDefinitionList = (Seq[ColumnDefinition], Seq[ConstraintExpression]) + type TableElementList = (Seq[ColumnDefinition], Seq[ConstraintExpression]) /** * Validate a create table statement and return the [[TableIdentifier]]. 
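[For illustration only, not part of the patch: the new tableElementList rule treats column definitions and table-level constraint specs as peer elements, so constraints no longer have to trail the column list, and column-level constraints keep working through colDefinitionOption. A small sketch, again using the parser entry point from the test suites:

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan

    // Column-level constraint attached to a column definition.
    parsePlan("CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0), b STRING) USING parquet")

    // Table-level constraint as its own element, interleaved with the columns.
    parsePlan("CREATE TABLE t (a INT, CONSTRAINT c1 CHECK (a > 0), b STRING) USING parquet")

visitTableElementList then simply partitions the elements back into (columns, constraints) before they are packed into the UnresolvedTableSpec.]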
@@ -4711,16 +4709,24 @@ class AstBuilder extends DataTypeAstBuilder } } - private def visitColumnDefinitionList( - colCtx: ColDefinitionListContext, - constraintListContext: ConstraintListWithLeadingCommaContext): ColumnDefinitionList = { - val (columns, colConstraints) = Option(colCtx).map(visitColDefinitionList).getOrElse((Nil, Nil)) + override def visitTableElementList(ctx: TableElementListContext): TableElementList = { + if (ctx == null) { + return (Nil, Nil) + } + val columnDefs = new ArrayBuffer[ColumnDefinition]() + val constraints = new ArrayBuffer[ConstraintExpression]() - val tableConstraints = Option(constraintListContext) - .map(visitConstraintListWithLeadingComma) - .getOrElse(Nil) + ctx.tableElement().asScala.foreach { element => + if (element.constraintSpec() != null) { + constraints += visitConstraintSpec(element.constraintSpec()) + } else { + val (colDef, constraintOpt) = visitColDefinition(element.colDefinition()) + columnDefs += colDef + constraintOpt.foreach(constraints += _) + } + } - (columns, colConstraints ++ tableConstraints) + (columnDefs.toSeq, constraints.toSeq) } /** @@ -4757,8 +4763,7 @@ class AstBuilder extends DataTypeAstBuilder val (identifierContext, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - val (columns, constraints) = - visitColumnDefinitionList(ctx.colDefinitionList(), ctx.constraintListWithLeadingComma()) + val (columns, constraints) = visitTableElementList(ctx.tableElementList()) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) val (partTransforms, partCols, bucketSpec, properties, options, location, comment, @@ -4842,8 +4847,7 @@ class AstBuilder extends DataTypeAstBuilder val orCreate = ctx.replaceTableHeader().CREATE() != null val (partTransforms, partCols, bucketSpec, properties, options, location, comment, collation, serdeInfo, clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses()) - val (columns, constraints) = - visitColumnDefinitionList(ctx.colDefinitionList(), ctx.constraintListWithLeadingComma()) + val (columns, constraints) = visitTableElementList(ctx.tableElementList()) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) if (provider.isDefined && serdeInfo.isDefined) { @@ -5338,11 +5342,6 @@ class AstBuilder extends DataTypeAstBuilder ConstraintCharacteristic(enforcement.map(_ == "ENFORCED"), rely.map(_ == "RELY")) } - override def visitConstraintListWithLeadingComma( - ctx: ConstraintListWithLeadingCommaContext): Seq[ConstraintExpression] = { - ctx.constraintSpec().asScala.map(visitConstraintSpec).toSeq - } - /** * Parse an [[AlterTableCommand]] with table constraint. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8859b7b421b3..b40cb82e9cc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -365,7 +365,7 @@ class SparkSqlAstBuilder extends AstBuilder { visitCreateTableClauses(ctx.createTableClauses()) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText).getOrElse( throw QueryParsingErrors.createTempTableNotSpecifyProviderError(ctx)) - val schema = Option(ctx.colDefinitionList()).map(createSchema) + val schema = Option(ctx.tableElementList()).map(createSchema) logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " + "CREATE TEMPORARY VIEW ... 
USING ... instead") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala index 0d8d98f1f20f..7169f4cf084d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTabl import org.apache.spark.sql.types.{IntegerType, StringType} class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { - val createTablePrefix = "CREATE TABLE t (a INT, b STRING) USING parquet" + val createTablePrefix = "CREATE TABLE t (a INT, b STRING" + val createTableSuffix = ") USING parquet" val tableId = UnresolvedIdentifier(Seq("t")) val columns = Seq( ColumnDefinition("a", IntegerType), @@ -36,7 +37,8 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { val sql = s""" |$createTablePrefix - |$constraintStr + |, $constraintStr + |$createTableSuffix |""".stripMargin val parsed = parsePlan(sql) From ff3133adecc883895fd26a9ffcc497d717d0e406 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 26 Mar 2025 14:19:32 -0700 Subject: [PATCH 36/65] revise syntax; fix test failures --- .../sql/catalyst/parser/SqlBaseParser.g4 | 11 ++--- .../sql/catalyst/parser/AstBuilder.scala | 9 ++--- .../CreateTableConstraintParseSuite.scala | 40 ++++++++++++++----- 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 02aca9105453..1fab9ac455b4 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1293,6 +1293,7 @@ type | INTERVAL | VARIANT | ARRAY | STRUCT | MAP + | unsupportedType=identifier ; dataType @@ -1342,8 +1343,8 @@ tableElementList ; tableElement - : colDefinition - | constraintSpec + : constraintSpec + | colDefinition ; colDefinitionList @@ -1530,11 +1531,7 @@ number ; constraintSpec - : constraintName? constraintExpression constraintCharacteristic* - ; - -constraintName - : CONSTRAINT name=errorCapturingIdentifier + : (CONSTRAINT name=errorCapturingIdentifier)? 
constraintExpression constraintCharacteristic* ; constraintExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2e7c0b75ddda..7790c4f06f0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3926,6 +3926,7 @@ class AstBuilder extends DataTypeAstBuilder throw QueryParsingErrors.duplicateTableColumnDescriptor( option, name, "CONSTRAINT") } + constraintSpec = Some(spec) } } @@ -5290,8 +5291,8 @@ class AstBuilder extends DataTypeAstBuilder override def visitConstraintSpec(ctx: ConstraintSpecContext): ConstraintExpression = withOrigin(ctx) { - val name = if (ctx.constraintName() != null) { - visitConstraintName(ctx.constraintName()) + val name = if (ctx.name != null) { + ctx.name.getText } else { null } @@ -5302,10 +5303,6 @@ class AstBuilder extends DataTypeAstBuilder expr.withNameAndCharacteristic(name, constraintCharacteristic) } - override def visitConstraintName(ctx: ConstraintNameContext): String = { - ctx.name.getText - } - override def visitCheckConstraint(ctx: CheckConstraintContext): CheckConstraint = withOrigin(ctx) { val condition = getOriginalText(ctx.expr) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala index 7169f4cf084d..0814a69ec49e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala @@ -34,12 +34,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { ) def verifyConstraints(constraintStr: String, constraints: Constraints): Unit = { - val sql = - s""" - |$createTablePrefix - |, $constraintStr - |$createTableSuffix - |""".stripMargin + val sql = s"$createTablePrefix, $constraintStr $createTableSuffix" val parsed = parsePlan(sql) val tableSpec = UnresolvedTableSpec( @@ -60,7 +55,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { } test("Create table with two check constraints") { - val constraintStr = "CONSTRAINT c1 CHECK (a > 0) CONSTRAINT c2 CHECK (b = 'foo')" + val constraintStr = "CONSTRAINT c1 CHECK (a > 0), CONSTRAINT c2 CHECK (b = 'foo')" val constraint1 = CheckConstraint( child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), condition = "a > 0", @@ -91,16 +86,41 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2" val expectedContext = ExpectedContext( fragment = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2", - start = 47, - stop = 75 + characteristic1.length + characteristic2.length + start = 33, + stop = 61 + characteristic1.length + characteristic2.length ) checkError( exception = intercept[ParseException] { - parsePlan(s"$createTablePrefix $constraintStr") + parsePlan(s"$createTablePrefix, $constraintStr $createTableSuffix") }, condition = "INVALID_CONSTRAINT_CHARACTERISTICS", parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), queryContext = Array(expectedContext)) } } + + test("Create table with column 'constraint'") { + val sql = "CREATE TABLE t (constraint 
STRING) USING parquet" + val columns = Seq(ColumnDefinition("constraint", StringType)) + val tableSpec = UnresolvedTableSpec( + Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), + None, None, None, None, false, Constraints(Seq.empty)) + val expected = CreateTable(tableId, columns, Seq.empty, tableSpec, false) + comparePlans(parsePlan(sql), expected) + } + + test("Create table with column 'constraint' and check constraint") { + val sql = "CREATE TABLE t (constraint STRING CONSTRAINT c1 CHECK (constraint = 'foo'))" + + " USING parquet" + val columns = Seq(ColumnDefinition("constraint", StringType)) + val constraint = CheckConstraint( + child = EqualTo(UnresolvedAttribute("constraint"), Literal("foo")), + condition = "constraint = 'foo'", + name = "c1") + val tableSpec = UnresolvedTableSpec( + Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), + None, None, None, None, false, Constraints(Seq(constraint))) + val expected = CreateTable(tableId, columns, Seq.empty, tableSpec, false) + comparePlans(parsePlan(sql), expected) + } } From 7a33e57a3a58b6a3d15985b586343c104f626415 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 26 Mar 2025 15:19:31 -0700 Subject: [PATCH 37/65] refactor test code --- .../CreateTableConstraintParseSuite.scala | 100 ++++++++++++------ .../command/v2/CheckConstraintSuite.scala | 2 +- 2 files changed, 71 insertions(+), 31 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala index 0814a69ec49e..bb36512be85b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala @@ -25,37 +25,67 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTabl import org.apache.spark.sql.types.{IntegerType, StringType} class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { - val createTablePrefix = "CREATE TABLE t (a INT, b STRING" - val createTableSuffix = ") USING parquet" - val tableId = UnresolvedIdentifier(Seq("t")) - val columns = Seq( - ColumnDefinition("a", IntegerType), - ColumnDefinition("b", StringType) - ) - def verifyConstraints(constraintStr: String, constraints: Constraints): Unit = { - val sql = s"$createTablePrefix, $constraintStr $createTableSuffix" - - val parsed = parsePlan(sql) + def createExpectedPlan( + columns: Seq[ColumnDefinition], + constraints: Constraints): CreateTable = { + val tableId = UnresolvedIdentifier(Seq("t")) val tableSpec = UnresolvedTableSpec( Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), None, None, None, None, false, constraints) - val expected = CreateTable(tableId, columns, Seq.empty, tableSpec, false) + CreateTable(tableId, columns, Seq.empty, tableSpec, false) + } + + def verifyConstraints(sql: String, constraints: Constraints): Unit = { + val parsed = parsePlan(sql) + val columns = Seq( + ColumnDefinition("a", IntegerType), + ColumnDefinition("b", StringType) + ) + val expected = createExpectedPlan(columns = columns, constraints = constraints) comparePlans(parsed, expected) } - test("Create table with one check constraint") { - val constraintStr = "CONSTRAINT c1 CHECK (a > 0)" + test("Create table with one check constraint - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 
0)) USING parquet" + val constraint = CheckConstraint( + child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), + condition = "a > 0", + name = "c1") + val constraints = Constraints(Seq(constraint)) + verifyConstraints(sql, constraints) + } + + test("Create table with one check constraint - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0), b STRING) USING parquet" val constraint = CheckConstraint( child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), condition = "a > 0", name = "c1") val constraints = Constraints(Seq(constraint)) - verifyConstraints(constraintStr, constraints) + verifyConstraints(sql, constraints) + } + + + test("Create table with two check constraints - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0), " + + "CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" + val constraint1 = CheckConstraint( + child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), + condition = "a > 0", + name = "c1") + val constraint2 = CheckConstraint( + child = EqualTo(UnresolvedAttribute("b"), Literal("foo")), + condition = "b = 'foo'", + name = "c2") + val constraints = Constraints(Seq(constraint1, constraint2)) + verifyConstraints(sql, constraints) } - test("Create table with two check constraints") { - val constraintStr = "CONSTRAINT c1 CHECK (a > 0), CONSTRAINT c2 CHECK (b = 'foo')" + + test("Create table with two check constraints - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0), " + + "b STRING CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" val constraint1 = CheckConstraint( child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), condition = "a > 0", @@ -65,19 +95,35 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { condition = "b = 'foo'", name = "c2") val constraints = Constraints(Seq(constraint1, constraint2)) - verifyConstraints(constraintStr, constraints) + verifyConstraints(sql, constraints) + } + + + test("Create table with valid characteristic - table level") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = s"CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0) " + + s"$enforcedStr $relyStr) USING parquet" + val constraint = CheckConstraint( + child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), + condition = "a > 0", + name = "c1", + characteristic = characteristic) + val constraints = Constraints(Seq(constraint)) + verifyConstraints(sql, constraints) + } } - test("Create table with valid characteristic") { + test("Create table with valid characteristic - column level") { validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => - val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $enforcedStr $relyStr" + val sql = s"CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0) " + + s"$enforcedStr $relyStr, b STRING) USING parquet" val constraint = CheckConstraint( child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), condition = "a > 0", name = "c1", characteristic = characteristic) val constraints = Constraints(Seq(constraint)) - verifyConstraints(constraintStr, constraints) + verifyConstraints(sql, constraints) } } @@ -91,7 +137,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { ) checkError( exception = intercept[ParseException] { - parsePlan(s"$createTablePrefix, $constraintStr $createTableSuffix") + parsePlan(s"CREATE TABLE t (a INT, b STRING, $constraintStr ) USING parquet") }, condition = 
"INVALID_CONSTRAINT_CHARACTERISTICS", parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), @@ -102,10 +148,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { test("Create table with column 'constraint'") { val sql = "CREATE TABLE t (constraint STRING) USING parquet" val columns = Seq(ColumnDefinition("constraint", StringType)) - val tableSpec = UnresolvedTableSpec( - Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), - None, None, None, None, false, Constraints(Seq.empty)) - val expected = CreateTable(tableId, columns, Seq.empty, tableSpec, false) + val expected = createExpectedPlan(columns, Constraints.empty) comparePlans(parsePlan(sql), expected) } @@ -117,10 +160,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { child = EqualTo(UnresolvedAttribute("constraint"), Literal("foo")), condition = "constraint = 'foo'", name = "c1") - val tableSpec = UnresolvedTableSpec( - Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), - None, None, None, None, false, Constraints(Seq(constraint))) - val expected = CreateTable(tableId, columns, Seq.empty, tableSpec, false) + val expected = createExpectedPlan(columns, Constraints(Seq(constraint))) comparePlans(parsePlan(sql), expected) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index 329b6840d363..b6d60912f5f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -110,7 +110,7 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => withNamespaceAndTable("ns", "tbl", catalog) { t => val constraintStr = s"CONSTRAINT c1 CHECK (id > 0) $characteristic" - sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing $constraintStr") + sql(s"CREATE TABLE $t (id bigint, data string, $constraintStr) $defaultUsing") val table = loadTable(catalog, "ns", "tbl") val constraint = getCheckConstraint(table) assert(constraint.name() == "c1") From 1923e1bfcdcaf38a06247fe2d65689cdd70ccac8 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 26 Mar 2025 17:01:34 -0700 Subject: [PATCH 38/65] remove Expression Constraints --- .../catalyst/analysis/ResolveTableSpec.scala | 14 +++---- .../catalyst/expressions/constraints.scala | 38 +++---------------- .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../catalyst/plans/logical/v2Commands.scala | 20 ++++------ ...eateTablePartitioningValidationSuite.scala | 4 +- .../sql/catalyst/parser/DDLParserSuite.scala | 18 ++++----- .../apache/spark/sql/classic/Catalog.scala | 4 +- .../spark/sql/classic/DataFrameWriter.scala | 8 ++-- .../spark/sql/classic/DataFrameWriterV2.scala | 6 +-- .../spark/sql/classic/DataStreamWriter.scala | 3 +- .../V2CommandsCaseSensitivitySuite.scala | 9 ++--- .../CreateTableConstraintParseSuite.scala | 22 +++++------ 12 files changed, 57 insertions(+), 93 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala index fc8982b93ff6..c6b21f92a0b5 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala @@ -67,9 +67,9 @@ object ResolveTableSpec extends Rule[LogicalPlan] { } private def analyzeConstraints( - constraints: Constraints, - fakeRelation: LogicalPlan): Constraints = { - val analyzedExpressions = constraints.children.map { + constraints: Seq[ConstraintExpression], + fakeRelation: LogicalPlan): Seq[ConstraintExpression] = { + val analyzedExpressions = constraints.map { case c: CheckConstraint => val alias = Alias(c.child, c.name)() val project = Project(Seq(alias), fakeRelation) @@ -79,10 +79,10 @@ object ResolveTableSpec extends Rule[LogicalPlan] { val analyzedExpression = analyzed collectFirst { case Project(Seq(Alias(e: Expression, _)), _) => e } - c.withNewChildren(Seq(analyzedExpression.get)) + c.withNewChildren(Seq(analyzedExpression.get)).asInstanceOf[CheckConstraint] case other => other } - Constraints(analyzedExpressions) + analyzedExpressions } /** Helper method to resolve the table specification within a logical plan. */ @@ -121,7 +121,7 @@ object ResolveTableSpec extends Rule[LogicalPlan] { } else { u.constraints } - assert(newConstraints.childrenResolved) + // assert(newConstraints.childrenResolved) val newTableSpec = TableSpec( properties = u.properties, provider = u.provider, @@ -131,7 +131,7 @@ object ResolveTableSpec extends Rule[LogicalPlan] { collation = u.collation, serde = u.serde, external = u.external, - constraints = newConstraints.asConstraintList) + constraints = newConstraints.map(_.asConstraint)) withNewSpec(newTableSpec) case _ => input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index afcbc516243c..443a3cb67f40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -16,18 +16,12 @@ */ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.types.{DataType, StringType} -trait ConstraintExpression extends Expression with Unevaluable { - override def nullable: Boolean = true - - override def dataType: DataType = StringType - +trait ConstraintExpression { def asConstraint: Constraint def withNameAndCharacteristic( @@ -54,8 +48,9 @@ case class CheckConstraint( condition: String, override val name: String = null, override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) - extends ConstraintExpression - with UnaryLike[Expression] { + extends UnaryExpression + with Unevaluable + with ConstraintExpression { def asConstraint: Constraint = { val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull @@ -91,29 +86,6 @@ case class CheckConstraint( ConstraintCharacteristic(enforced = Some(true), rely = Some(true)) override def sql: String = s"CONSTRAINT $name CHECK ($condition)" -} - -/* - * A list of constraints that are applied to a table. 
- */ -case class Constraints(children: Seq[Expression]) extends Expression with Unevaluable { - - assert(children.forall(_.isInstanceOf[ConstraintExpression])) - override def nullable: Boolean = true - - override def dataType: DataType = - throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3113") - - override protected def withNewChildrenInternal( - newChildren: IndexedSeq[Expression]): Expression = { - copy(children = newChildren) - } - - def asConstraintList: Seq[Constraint] = - children.map(_.asInstanceOf[ConstraintExpression].asConstraint) -} - -object Constraints { - val empty: Constraints = Constraints(Nil) + override def dataType: DataType = StringType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7790c4f06f0b..7a833cc8909e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -4787,7 +4787,7 @@ class AstBuilder extends DataTypeAstBuilder clusterBySpec.map(_.asTransform) val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, - collation, serdeInfo, external, Constraints(constraints)) + collation, serdeInfo, external, constraints) Option(ctx.query).map(plan) match { case Some(_) if columns.nonEmpty => @@ -4861,7 +4861,7 @@ class AstBuilder extends DataTypeAstBuilder clusterBySpec.map(_.asTransform) val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, - collation, serdeInfo, external = false, Constraints(constraints)) + collation, serdeInfo, external = false, constraints) Option(ctx.query).map(plan) match { case Some(_) if columns.nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 914257a40c74..85b3b7b28ba3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -1522,26 +1522,20 @@ case class UnresolvedTableSpec( collation: Option[String], serde: Option[SerdeInfo], external: Boolean, - constraints: Constraints) - extends BinaryExpression with Unevaluable with TableSpecBase { + constraints: Seq[ConstraintExpression]) + extends UnaryExpression with Unevaluable with TableSpecBase { override def dataType: DataType = throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3113") + override def child: Expression = optionExpression + + override protected def withNewChildInternal(newChild: Expression): Expression = + this.copy(optionExpression = newChild.asInstanceOf[OptionList]) + override def simpleString(maxFields: Int): String = { this.copy(properties = Utils.redact(properties).toMap).toString } - - override def left: Expression = optionExpression - - override def right: Expression = constraints - - override protected def withNewChildrenInternal( - newLeft: Expression, newRight: Expression): Expression = - copy(optionExpression = newLeft.asInstanceOf[OptionList], - constraints = newRight.asInstanceOf[Constraints]) - - override def nullable: Boolean = true } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala index bf94758267fb..0afdffb8b5e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Constraints} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode, OptionList, UnresolvedTableSpec} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, Table, TableCapability, TableCatalog} @@ -31,7 +31,7 @@ import org.apache.spark.util.ArrayImplicits._ class CreateTablePartitioningValidationSuite extends AnalysisTest { val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, None, false, - Constraints.empty) + Seq.empty) test("CreateTableAsSelect: fail missing top-level column") { val plan = CreateTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 6d4df1432d87..1589bcb8a3d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.SparkThrowable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.{Constraints, EqualTo, Hex, Literal} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Hex, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.catalog.IdentityColumnSpec import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.{after, first} @@ -2705,7 +2705,7 @@ class DDLParserSuite extends AnalysisTest { val createTableResult = CreateTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithDefaultValue, Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), - OptionList(Seq.empty), None, None, None, None, false, Constraints.empty), false) + OptionList(Seq.empty), None, None, None, None, false, Seq.empty), false) // Parse the CREATE TABLE statement twice, swapping the order of the NOT NULL and DEFAULT // options, to make sure that the parser accepts any ordering of these options. comparePlans(parsePlan( @@ -2718,7 +2718,7 @@ class DDLParserSuite extends AnalysisTest { "b STRING NOT NULL DEFAULT 'abc') USING parquet"), ReplaceTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithDefaultValue, Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), - OptionList(Seq.empty), None, None, None, None, false, Constraints.empty), false)) + OptionList(Seq.empty), None, None, None, None, false, Seq.empty), false)) // These ALTER TABLE statements should parse successfully. 
comparePlans( parsePlan("ALTER TABLE t1 ADD COLUMN x int NOT NULL DEFAULT 42"), @@ -2881,12 +2881,12 @@ class DDLParserSuite extends AnalysisTest { "CREATE TABLE my_tab(a INT, b INT NOT NULL GENERATED ALWAYS AS (a+1)) USING parquet"), CreateTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithGenerationExpr, Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), - OptionList(Seq.empty), None, None, None, None, false, Constraints.empty), false)) + OptionList(Seq.empty), None, None, None, None, false, Seq.empty), false)) comparePlans(parsePlan( "REPLACE TABLE my_tab(a INT, b INT NOT NULL GENERATED ALWAYS AS (a+1)) USING parquet"), ReplaceTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithGenerationExpr, Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), - OptionList(Seq.empty), None, None, None, None, false, Constraints.empty), false)) + OptionList(Seq.empty), None, None, None, None, false, Seq.empty), false)) // Two generation expressions checkError( exception = parseException("CREATE TABLE my_tab(a INT, " + @@ -2958,7 +2958,7 @@ class DDLParserSuite extends AnalysisTest { None, None, false, - Constraints.empty + Seq.empty ), false ) @@ -2982,7 +2982,7 @@ class DDLParserSuite extends AnalysisTest { None, None, false, - Constraints.empty + Seq.empty ), false ) @@ -3275,7 +3275,7 @@ class DDLParserSuite extends AnalysisTest { Seq(ColumnDefinition("c", StringType)), Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), - None, None, Some(collation), None, false, Constraints.empty), false)) + None, None, Some(collation), None, false, Seq.empty), false)) } } @@ -3287,7 +3287,7 @@ class DDLParserSuite extends AnalysisTest { Seq(ColumnDefinition("c", StringType)), Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), - None, None, Some(collation), None, false, Constraints.empty), false)) + None, None, Some(collation), None, false, Seq.empty), false)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala index 30118744ca78..3b4f6475a6bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Constraints, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, LocalRelation, LogicalPlan, OptionList, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, UnresolvedTableSpec, View} import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -689,7 +689,7 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog { collation = None, serde = None, external = tableType == CatalogTableType.EXTERNAL, - constraints = Constraints.empty) + constraints = Seq.empty) val plan = CreateTable( name = UnresolvedIdentifier(ident), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala index 6733f31ec017..501b4985128d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, NoSuchTableException, UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Constraints, Literal} +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog._ @@ -214,7 +214,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram collation = extraOptions.get(TableCatalog.PROP_COLLATION), serde = None, external = false, - constraints = Constraints.empty) + constraints = Seq.empty) runCommand(df.sparkSession) { CreateTableAsSelect( UnresolvedIdentifier( @@ -480,7 +480,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram collation = extraOptions.get(TableCatalog.PROP_COLLATION), serde = None, external = false, - constraints = Constraints.empty) + constraints = Seq.empty) ReplaceTableAsSelect( UnresolvedIdentifier(nameParts), partitioningAsV2, @@ -502,7 +502,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram collation = extraOptions.get(TableCatalog.PROP_COLLATION), serde = None, external = false, - constraints = Constraints.empty) + constraints = Seq.empty) CreateTableAsSelect( UnresolvedIdentifier(nameParts), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala index fdedb1a50c47..01b3619f1236 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation} -import org.apache.spark.sql.catalyst.expressions.{Attribute, Constraints, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.catalog.TableWritePrivilege._ import org.apache.spark.sql.connector.expressions._ @@ -155,7 +155,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) collation = None, serde = None, external = false, - constraints = Constraints.empty) + constraints = Seq.empty) runCommand( CreateTableAsSelect( UnresolvedIdentifier(tableName), @@ -222,7 +222,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) collation = None, serde = None, external = false, - constraints = Constraints.empty) + constraints = Seq.empty) runCommand(ReplaceTableAsSelect( UnresolvedIdentifier(tableName), partitioning.getOrElse(Seq.empty) ++ clustering, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala index 4eff6adb28b9..471c5feadaab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala @@ -30,7 +30,6 @@ import org.apache.spark.sql.{streaming, Dataset => DS, ForeachWriter} import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.Constraints import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, OptionList, UnresolvedTableSpec} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -177,7 +176,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D None, None, external = false, - constraints = Constraints.empty) + constraints = Seq.empty) val cmd = CreateTable( UnresolvedIdentifier(originalMultipartIdentifier), ds.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala index 1f012099bda4..300492577b1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, TestTable2, UnresolvedFieldName, UnresolvedFieldPosition, UnresolvedIdentifier} -import org.apache.spark.sql.catalyst.expressions.Constraints import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumns, AlterColumnSpec, AlterTableCommand, CreateTableAsSelect, DropColumns, LogicalPlan, OptionList, QualifiedColType, RenameColumn, ReplaceColumns, ReplaceTableAsSelect, UnresolvedTableSpec} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes @@ -55,7 +54,7 @@ class V2CommandsCaseSensitivitySuite Seq("ID", "iD").foreach { ref => val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), - None, None, None, None, false, Constraints.empty) + None, None, None, None, false, Seq.empty) val plan = CreateTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.identity(ref) :: Nil, @@ -80,7 +79,7 @@ class V2CommandsCaseSensitivitySuite Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), - None, None, None, None, false, Constraints.empty) + None, None, None, None, false, Seq.empty) val plan = CreateTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, ref) :: Nil, @@ -106,7 +105,7 @@ class V2CommandsCaseSensitivitySuite Seq("ID", "iD").foreach { ref => val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), - None, None, None, None, false, Constraints.empty) + None, None, None, None, false, Seq.empty) val plan = ReplaceTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), 
Expressions.identity(ref) :: Nil, @@ -131,7 +130,7 @@ class V2CommandsCaseSensitivitySuite Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), - None, None, None, None, false, Constraints.empty) + None, None, None, None, false, Seq.empty) val plan = ReplaceTableAsSelect( UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, ref) :: Nil, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala index bb36512be85b..fd815e10d351 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} -import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, Constraints, EqualTo, GreaterThan, Literal} +import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, ConstraintExpression, EqualTo, GreaterThan, Literal} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, OptionList, UnresolvedTableSpec} @@ -28,7 +28,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { def createExpectedPlan( columns: Seq[ColumnDefinition], - constraints: Constraints): CreateTable = { + constraints: Seq[ConstraintExpression]): CreateTable = { val tableId = UnresolvedIdentifier(Seq("t")) val tableSpec = UnresolvedTableSpec( Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), @@ -36,7 +36,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { CreateTable(tableId, columns, Seq.empty, tableSpec, false) } - def verifyConstraints(sql: String, constraints: Constraints): Unit = { + def verifyConstraints(sql: String, constraints: Seq[ConstraintExpression]): Unit = { val parsed = parsePlan(sql) val columns = Seq( ColumnDefinition("a", IntegerType), @@ -52,7 +52,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), condition = "a > 0", name = "c1") - val constraints = Constraints(Seq(constraint)) + val constraints = Seq(constraint) verifyConstraints(sql, constraints) } @@ -62,7 +62,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), condition = "a > 0", name = "c1") - val constraints = Constraints(Seq(constraint)) + val constraints = Seq(constraint) verifyConstraints(sql, constraints) } @@ -78,7 +78,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { child = EqualTo(UnresolvedAttribute("b"), Literal("foo")), condition = "b = 'foo'", name = "c2") - val constraints = Constraints(Seq(constraint1, constraint2)) + val constraints = Seq(constraint1, constraint2) verifyConstraints(sql, constraints) } @@ -94,7 +94,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { child = EqualTo(UnresolvedAttribute("b"), Literal("foo")), condition = "b = 'foo'", name = "c2") - val constraints = Constraints(Seq(constraint1, 
constraint2)) + val constraints = Seq(constraint1, constraint2) verifyConstraints(sql, constraints) } @@ -108,7 +108,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { condition = "a > 0", name = "c1", characteristic = characteristic) - val constraints = Constraints(Seq(constraint)) + val constraints = Seq(constraint) verifyConstraints(sql, constraints) } } @@ -122,7 +122,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { condition = "a > 0", name = "c1", characteristic = characteristic) - val constraints = Constraints(Seq(constraint)) + val constraints = Seq(constraint) verifyConstraints(sql, constraints) } } @@ -148,7 +148,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { test("Create table with column 'constraint'") { val sql = "CREATE TABLE t (constraint STRING) USING parquet" val columns = Seq(ColumnDefinition("constraint", StringType)) - val expected = createExpectedPlan(columns, Constraints.empty) + val expected = createExpectedPlan(columns, Seq.empty) comparePlans(parsePlan(sql), expected) } @@ -160,7 +160,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { child = EqualTo(UnresolvedAttribute("constraint"), Literal("foo")), condition = "constraint = 'foo'", name = "c1") - val expected = createExpectedPlan(columns, Constraints(Seq(constraint))) + val expected = createExpectedPlan(columns, Seq(constraint)) comparePlans(parsePlan(sql), expected) } } From 3c6cac54db1286c62222ac315eb3cfc6e9fc591f Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 26 Mar 2025 17:10:59 -0700 Subject: [PATCH 39/65] rename ConstraintExpression as TableConstraint --- .../spark/sql/catalyst/analysis/ResolveTableSpec.scala | 4 ++-- .../spark/sql/catalyst/expressions/constraints.scala | 8 ++++---- .../apache/spark/sql/catalyst/parser/AstBuilder.scala | 10 +++++----- .../spark/sql/catalyst/plans/logical/v2Commands.scala | 2 +- .../command/CreateTableConstraintParseSuite.scala | 6 +++--- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala index c6b21f92a0b5..e4ce620717b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala @@ -67,8 +67,8 @@ object ResolveTableSpec extends Rule[LogicalPlan] { } private def analyzeConstraints( - constraints: Seq[ConstraintExpression], - fakeRelation: LogicalPlan): Seq[ConstraintExpression] = { + constraints: Seq[TableConstraint], + fakeRelation: LogicalPlan): Seq[TableConstraint] = { val analyzedExpressions = constraints.map { case c: CheckConstraint => val alias = Alias(c.child, c.name)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index 443a3cb67f40..99ec957cf8bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -21,12 +21,12 @@ import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.types.{DataType, StringType} -trait ConstraintExpression { 
+trait TableConstraint { def asConstraint: Constraint def withNameAndCharacteristic( name: String, - c: ConstraintCharacteristic): ConstraintExpression + c: ConstraintCharacteristic): TableConstraint def name: String @@ -50,7 +50,7 @@ case class CheckConstraint( override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) extends UnaryExpression with Unevaluable - with ConstraintExpression { + with TableConstraint { def asConstraint: Constraint = { val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull @@ -73,7 +73,7 @@ case class CheckConstraint( override def withNameAndCharacteristic( name: String, - c: ConstraintCharacteristic): ConstraintExpression = { + c: ConstraintCharacteristic): TableConstraint = { copy(name = name, characteristic = c) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7a833cc8909e..b77927fe992f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -4166,9 +4166,9 @@ class AstBuilder extends DataTypeAstBuilder Seq[Transform], Seq[ColumnDefinition], Option[BucketSpec], Map[String, String], OptionList, Option[String], Option[String], Option[String], Option[SerdeInfo], Option[ClusterBySpec]) - type ColumnAndConstraint = (ColumnDefinition, Option[ConstraintExpression]) + type ColumnAndConstraint = (ColumnDefinition, Option[TableConstraint]) - type TableElementList = (Seq[ColumnDefinition], Seq[ConstraintExpression]) + type TableElementList = (Seq[ColumnDefinition], Seq[TableConstraint]) /** * Validate a create table statement and return the [[TableIdentifier]]. 
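[For illustration only, not part of the patch: after this rename the parser-facing abstraction is TableConstraint (implemented by CheckConstraint), while the connector-facing Constraint is produced by asConstraint once the spec is resolved. A minimal sketch built from the expressions shown in this series:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, GreaterThan, Literal}

    // Catalyst-side representation kept by the parser: the original constraint text
    // plus an (initially unresolved) boolean expression.
    val check = CheckConstraint(
      child = GreaterThan(UnresolvedAttribute("a"), Literal(0)),
      condition = "a > 0",
      name = "c1")

    // Lowering to the DSv2 connector API; the predicate comes from V2ExpressionBuilder,
    // so it is only meaningful after `child` has been resolved against the table schema.
    val v2Constraint = check.asConstraint
]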
@@ -4715,7 +4715,7 @@ class AstBuilder extends DataTypeAstBuilder return (Nil, Nil) } val columnDefs = new ArrayBuffer[ColumnDefinition]() - val constraints = new ArrayBuffer[ConstraintExpression]() + val constraints = new ArrayBuffer[TableConstraint]() ctx.tableElement().asScala.foreach { element => if (element.constraintSpec() != null) { @@ -5289,7 +5289,7 @@ class AstBuilder extends DataTypeAstBuilder AlterTableCollation(table, visitCollationSpec(ctx.collationSpec())) } - override def visitConstraintSpec(ctx: ConstraintSpecContext): ConstraintExpression = + override def visitConstraintSpec(ctx: ConstraintSpecContext): TableConstraint = withOrigin(ctx) { val name = if (ctx.name != null) { ctx.name.getText @@ -5298,7 +5298,7 @@ class AstBuilder extends DataTypeAstBuilder } val constraintCharacteristic = visitConstraintCharacteristic(ctx) val expr = - visitConstraintExpression(ctx.constraintExpression()).asInstanceOf[ConstraintExpression] + visitConstraintExpression(ctx.constraintExpression()).asInstanceOf[TableConstraint] expr.withNameAndCharacteristic(name, constraintCharacteristic) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 85b3b7b28ba3..07b6e912b584 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -1522,7 +1522,7 @@ case class UnresolvedTableSpec( collation: Option[String], serde: Option[SerdeInfo], external: Boolean, - constraints: Seq[ConstraintExpression]) + constraints: Seq[TableConstraint]) extends UnaryExpression with Unevaluable with TableSpecBase { override def dataType: DataType = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala index fd815e10d351..7eb651186f71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} -import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, ConstraintExpression, EqualTo, GreaterThan, Literal} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, OptionList, UnresolvedTableSpec} @@ -28,7 +28,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { def createExpectedPlan( columns: Seq[ColumnDefinition], - constraints: Seq[ConstraintExpression]): CreateTable = { + constraints: Seq[TableConstraint]): CreateTable = { val tableId = UnresolvedIdentifier(Seq("t")) val tableSpec = UnresolvedTableSpec( Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), @@ -36,7 +36,7 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { CreateTable(tableId, columns, Seq.empty, tableSpec, false) } - def verifyConstraints(sql: String, constraints: Seq[ConstraintExpression]): Unit = { + def verifyConstraints(sql: 
String, constraints: Seq[TableConstraint]): Unit = { val parsed = parsePlan(sql) val columns = Seq( ColumnDefinition("a", IntegerType), From 7aea90c23ab2a96029a142e7aab124fd269467d5 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 26 Mar 2025 21:11:50 -0700 Subject: [PATCH 40/65] refactor syntax as per standard --- .../sql/catalyst/parser/SqlBaseParser.g4 | 35 +++++++++++++------ .../sql/catalyst/parser/AstBuilder.scala | 29 +++++++++------ 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 1fab9ac455b4..c000da0b3b3d 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -261,7 +261,7 @@ statement | ALTER TABLE identifierReference (clusterBySpec | CLUSTER BY NONE) #alterClusterBy | ALTER TABLE identifierReference collationSpec #alterTableCollation - | ALTER TABLE identifierReference ADD constraintSpec #addTableConstraint + | ALTER TABLE identifierReference ADD tableConstraintDefinition #addTableConstraint | ALTER TABLE identifierReference DROP CONSTRAINT (IF EXISTS)? name=identifier (RESTRICT | CASCADE)? #dropTableConstraint @@ -1343,7 +1343,7 @@ tableElementList ; tableElement - : constraintSpec + : tableConstraintDefinition | colDefinition ; @@ -1360,7 +1360,7 @@ colDefinitionOption | defaultExpression | generationExpression | commentSpec - | constraintSpec + | columnConstraint ; generationExpression @@ -1530,14 +1530,22 @@ number | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral ; -constraintSpec - : (CONSTRAINT name=errorCapturingIdentifier)? constraintExpression constraintCharacteristic* +columnConstraintDefinition + : (CONSTRAINT name=errorCapturingIdentifier)? columnConstraint constraintCharacteristic* ; -constraintExpression +columnConstraint + : uniqueSpec + | referenceSpec + ; + +tableConstraintDefinition + : (CONSTRAINT name=errorCapturingIdentifier)? tableConstraint constraintCharacteristic* + ; + +tableConstraint : checkConstraint | uniqueConstraint - | primaryKeyConstraint | foreignKeyConstraint ; @@ -1545,16 +1553,21 @@ checkConstraint : CHECK LEFT_PAREN (expr=booleanExpression) RIGHT_PAREN ; +uniqueSpec + : UNIQUE + | PRIMARY KEY + ; + uniqueConstraint - : UNIQUE identifierList + : uniqueSpec identifierList ; -primaryKeyConstraint - : PRIMARY KEY identifierList +referenceSpec + : REFERENCES multipartIdentifier (LEFT_PAREN parentColumns=identifierList RIGHT_PAREN)? 
; foreignKeyConstraint - : FOREIGN KEY identifierList REFERENCES table=multipartIdentifier identifierList + : FOREIGN KEY LEFT_PAREN identifierList RIGHT_PAREN referenceSpec ; constraintCharacteristic diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b77927fe992f..d3a077f28a60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3887,7 +3887,7 @@ class AstBuilder extends DataTypeAstBuilder var defaultExpression: Option[DefaultExpressionContext] = None var generationExpression: Option[GenerationExpressionContext] = None var commentSpec: Option[CommentSpecContext] = None - var constraintSpec: Option[ConstraintSpecContext] = None + var columnConstraint: Option[ColumnConstraintContext] = None ctx.colDefinitionOption().asScala.foreach { option => if (option.NULL != null) { blockBang(option.errorCapturingNot) @@ -3921,12 +3921,12 @@ class AstBuilder extends DataTypeAstBuilder } commentSpec = Some(spec) } - Option(option.constraintSpec()).foreach { spec => - if (constraintSpec.isDefined) { + Option(option.columnConstraint()).foreach { spec => + if (columnConstraint.isDefined) { throw QueryParsingErrors.duplicateTableColumnDescriptor( option, name, "CONSTRAINT") } - constraintSpec = Some(spec) + columnConstraint = Some(spec) } } @@ -3944,10 +3944,16 @@ class AstBuilder extends DataTypeAstBuilder case ctx: IdentityColumnContext => visitIdentityColumn(ctx, dataType) } ) - val constraint = constraintSpec.map(visitConstraintSpec) + val constraint = columnConstraint.map(c => visitColumnConstraint(name, c)) (columnDef, constraint) } + private def visitColumnConstraint( + columnName: String, + ctx: ColumnConstraintContext): TableConstraint = { + null + } + /** * Create a location string. 
*/ @@ -4718,8 +4724,8 @@ class AstBuilder extends DataTypeAstBuilder val constraints = new ArrayBuffer[TableConstraint]() ctx.tableElement().asScala.foreach { element => - if (element.constraintSpec() != null) { - constraints += visitConstraintSpec(element.constraintSpec()) + if (element.tableConstraintDefinition() != null) { + constraints += visitTableConstraintDefinition(element.tableConstraintDefinition()) } else { val (colDef, constraintOpt) = visitColDefinition(element.colDefinition()) columnDefs += colDef @@ -5289,7 +5295,8 @@ class AstBuilder extends DataTypeAstBuilder AlterTableCollation(table, visitCollationSpec(ctx.collationSpec())) } - override def visitConstraintSpec(ctx: ConstraintSpecContext): TableConstraint = + override def visitTableConstraintDefinition( + ctx: TableConstraintDefinitionContext): TableConstraint = withOrigin(ctx) { val name = if (ctx.name != null) { ctx.name.getText @@ -5298,7 +5305,7 @@ class AstBuilder extends DataTypeAstBuilder } val constraintCharacteristic = visitConstraintCharacteristic(ctx) val expr = - visitConstraintExpression(ctx.constraintExpression()).asInstanceOf[TableConstraint] + visitTableConstraint(ctx.tableConstraint()).asInstanceOf[TableConstraint] expr.withNameAndCharacteristic(name, constraintCharacteristic) } @@ -5312,7 +5319,7 @@ class AstBuilder extends DataTypeAstBuilder } private def visitConstraintCharacteristic( - ctx: ConstraintSpecContext): ConstraintCharacteristic = { + ctx: TableConstraintDefinitionContext): ConstraintCharacteristic = { var enforcement: Option[String] = None var rely: Option[String] = None ctx.constraintCharacteristic().asScala.foreach { @@ -5351,7 +5358,7 @@ class AstBuilder extends DataTypeAstBuilder withOrigin(ctx) { val table = createUnresolvedTable( ctx.identifierReference, "ALTER TABLE ... 
ADD CONSTRAINT") - visitConstraintSpec(ctx.constraintSpec) match { + visitTableConstraintDefinition(ctx.tableConstraintDefinition()) match { case c: CheckConstraint => AddCheckConstraint(table, c) } From 78c3904ea744e218142173321e22f5d51a000d56 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 26 Mar 2025 21:14:34 -0700 Subject: [PATCH 41/65] remove column check constraint test cases --- .../CreateTableConstraintParseSuite.scala | 54 ------------------- 1 file changed, 54 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala index 7eb651186f71..c9eafe69de21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala @@ -56,17 +56,6 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { verifyConstraints(sql, constraints) } - test("Create table with one check constraint - column level") { - val sql = "CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0), b STRING) USING parquet" - val constraint = CheckConstraint( - child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), - condition = "a > 0", - name = "c1") - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with two check constraints - table level") { val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0), " + "CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" @@ -82,23 +71,6 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { verifyConstraints(sql, constraints) } - - test("Create table with two check constraints - column level") { - val sql = "CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0), " + - "b STRING CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" - val constraint1 = CheckConstraint( - child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), - condition = "a > 0", - name = "c1") - val constraint2 = CheckConstraint( - child = EqualTo(UnresolvedAttribute("b"), Literal("foo")), - condition = "b = 'foo'", - name = "c2") - val constraints = Seq(constraint1, constraint2) - verifyConstraints(sql, constraints) - } - - test("Create table with valid characteristic - table level") { validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => val sql = s"CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0) " + @@ -113,20 +85,6 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { } } - test("Create table with valid characteristic - column level") { - validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => - val sql = s"CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0) " + - s"$enforcedStr $relyStr, b STRING) USING parquet" - val constraint = CheckConstraint( - child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), - condition = "a > 0", - name = "c1", - characteristic = characteristic) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - } - test("Create table with invalid characteristic") { invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2" @@ -151,16 +109,4 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { val expected = 
createExpectedPlan(columns, Seq.empty) comparePlans(parsePlan(sql), expected) } - - test("Create table with column 'constraint' and check constraint") { - val sql = "CREATE TABLE t (constraint STRING CONSTRAINT c1 CHECK (constraint = 'foo'))" + - " USING parquet" - val columns = Seq(ColumnDefinition("constraint", StringType)) - val constraint = CheckConstraint( - child = EqualTo(UnresolvedAttribute("constraint"), Literal("foo")), - condition = "constraint = 'foo'", - name = "c1") - val expected = createExpectedPlan(columns, Seq(constraint)) - comparePlans(parsePlan(sql), expected) - } } From 20d7cf9bc1860bf738baef57ef0912fa73b4a5b4 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 26 Mar 2025 21:55:06 -0700 Subject: [PATCH 42/65] add pk&fk --- .../catalyst/expressions/constraints.scala | 75 ++++++++++++++++++- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index 99ec957cf8bd..fa886754f803 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder +import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.constraints.Constraint +import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.types.{DataType, StringType} trait TableConstraint { @@ -35,6 +37,12 @@ trait TableConstraint { def defaultName: String def defaultConstraintCharacteristic: ConstraintCharacteristic + + protected def getCharacteristicValues: (Boolean, Boolean) = { + val rely = characteristic.rely.getOrElse(defaultConstraintCharacteristic.rely.get) + val enforced = characteristic.enforced.getOrElse(defaultConstraintCharacteristic.enforced.get) + (rely, enforced) + } } case class ConstraintCharacteristic(enforced: Option[Boolean], rely: Option[Boolean]) @@ -54,9 +62,7 @@ case class CheckConstraint( def asConstraint: Constraint = { val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull - val rely = characteristic.rely.getOrElse(defaultConstraintCharacteristic.rely.get) - val enforced = - characteristic.enforced.getOrElse(defaultConstraintCharacteristic.enforced.get) + val (rely, enforced) = getCharacteristicValues val constraintName = if (name == null) defaultName else name Constraint .check(constraintName) @@ -89,3 +95,66 @@ case class CheckConstraint( override def dataType: DataType = StringType } + +case class PrimaryKeyConstraint( + columns: Seq[String], + override val name: String = null, + override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) + extends TableConstraint { + + override def asConstraint: Constraint = { + val (rely, enforced) = getCharacteristicValues + val constraintName = if (name == null) defaultName else name + Constraint + .primaryKey(constraintName, columns.map(FieldReference.column).toArray) + .rely(rely) + .enforced(enforced) + .validationStatus(Constraint.ValidationStatus.UNVALIDATED) + .build() + } + + override def withNameAndCharacteristic( + name: String, + c: ConstraintCharacteristic): TableConstraint = { + copy(name = name, characteristic = c) + } + + override 
def defaultName: String = "pk" + + override def defaultConstraintCharacteristic: ConstraintCharacteristic = + ConstraintCharacteristic(enforced = Some(true), rely = Some(true)) +} + +case class ForeignKeyConstraint( + override val name: String = null, + childColumns: Seq[String] = Seq.empty, + parentTable: Identifier = null, + parentColumns: Seq[String] = Seq.empty, + override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) + extends TableConstraint { + + override def asConstraint: Constraint = { + val (rely, enforced) = getCharacteristicValues + val constraintName = if (name == null) defaultName else name + Constraint + .foreignKey(constraintName, + childColumns.map(FieldReference.column).toArray, + parentTable, + parentColumns.map(FieldReference.column).toArray) + .rely(rely) + .enforced(enforced) + .validationStatus(Constraint.ValidationStatus.UNVALIDATED) + .build() + } + + override def withNameAndCharacteristic( + name: String, + c: ConstraintCharacteristic): TableConstraint = { + copy(name = name, characteristic = c) + } + + override def defaultName: String = "fk" + + override def defaultConstraintCharacteristic: ConstraintCharacteristic = + ConstraintCharacteristic(enforced = Some(true), rely = Some(true)) +} From af3fb9177677715835da6357251c408f2c164732 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 26 Mar 2025 23:33:20 -0700 Subject: [PATCH 43/65] parse column constraint --- .../catalyst/expressions/constraints.scala | 42 ++++++++++++++++--- .../sql/catalyst/parser/AstBuilder.scala | 22 +++++++++- 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index fa886754f803..fe68a7166746 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder -import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.types.{DataType, StringType} @@ -122,16 +121,49 @@ case class PrimaryKeyConstraint( override def defaultName: String = "pk" override def defaultConstraintCharacteristic: ConstraintCharacteristic = - ConstraintCharacteristic(enforced = Some(true), rely = Some(true)) + ConstraintCharacteristic(enforced = Some(false), rely = Some(true)) +} + +case class UniqueConstraint( + columns: Seq[String], + override val name: String = null, + override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) + extends TableConstraint { + + override def asConstraint: Constraint = { + val (rely, enforced) = getCharacteristicValues + val constraintName = if (name == null) defaultName else name + Constraint + .unique(constraintName, columns.map(FieldReference.column).toArray) + .rely(rely) + .enforced(enforced) + .validationStatus(Constraint.ValidationStatus.UNVALIDATED) + .build() + } + + override def withNameAndCharacteristic( + name: String, + c: ConstraintCharacteristic): TableConstraint = { + copy(name = name, characteristic = c) + } + + override def defaultName: String = + throw new 
AnalysisException( + errorClass = "INVALID_UNIQUE_CONSTRAINT.MISSING_NAME", + messageParameters = Map.empty) + + override def defaultConstraintCharacteristic: ConstraintCharacteristic = + ConstraintCharacteristic(enforced = Some(false), rely = Some(true)) } case class ForeignKeyConstraint( override val name: String = null, childColumns: Seq[String] = Seq.empty, - parentTable: Identifier = null, + parentTableId: Seq[String] = Seq.empty, parentColumns: Seq[String] = Seq.empty, override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) extends TableConstraint { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ override def asConstraint: Constraint = { val (rely, enforced) = getCharacteristicValues @@ -139,7 +171,7 @@ case class ForeignKeyConstraint( Constraint .foreignKey(constraintName, childColumns.map(FieldReference.column).toArray, - parentTable, + parentTableId.asIdentifier, parentColumns.map(FieldReference.column).toArray) .rely(rely) .enforced(enforced) @@ -156,5 +188,5 @@ case class ForeignKeyConstraint( override def defaultName: String = "fk" override def defaultConstraintCharacteristic: ConstraintCharacteristic = - ConstraintCharacteristic(enforced = Some(true), rely = Some(true)) + ConstraintCharacteristic(enforced = Some(false), rely = Some(true)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d3a077f28a60..7029abd8520d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3951,7 +3951,27 @@ class AstBuilder extends DataTypeAstBuilder private def visitColumnConstraint( columnName: String, ctx: ColumnConstraintContext): TableConstraint = { - null + val columns = Seq(columnName) + if (ctx.uniqueSpec() != null) { + if (ctx.uniqueSpec().UNIQUE() != null) { + UniqueConstraint(columns) + } else { + PrimaryKeyConstraint(columns) + } + } else { + assert(ctx.referenceSpec() != null) + val (tableId, refColumns) = visitReferenceSpec(ctx.referenceSpec()) + ForeignKeyConstraint( + childColumns = columns, + parentTableId = tableId, + parentColumns = refColumns) + } + } + + override def visitReferenceSpec(ctx: ReferenceSpecContext): (Seq[String], Seq[String]) = { + val tableId = visitMultipartIdentifier(ctx.multipartIdentifier()) + val refColumns = visitIdentifierList(ctx.parentColumns) + (tableId, refColumns) } /** From 337c35f3a979e78496b3424d298ff396ff02d9b7 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 27 Mar 2025 16:58:14 -0700 Subject: [PATCH 44/65] support PK & Unique; add tests --- .../resources/error/error-conditions.json | 6 ++ .../spark/sql/errors/QueryParsingErrors.scala | 4 + .../sql/catalyst/parser/AstBuilder.scala | 61 +++++++----- .../CreateTableConstraintParseSuite.scala | 94 +++++++++++++++++++ 4 files changed, 144 insertions(+), 21 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index bb24ed027c0c..b042bb134d76 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3927,6 +3927,12 @@ ], "sqlState" : "42P20" }, + "MULTIPLE_PRIMARY_KEYS" : { + "message" : [ + "Multiple primary keys are defined. Please ensure that only one primary key is defined for the table." 
+ ], + "sqlState" : "42K0E" + }, "MULTIPLE_QUERY_RESULT_CLAUSES_WITH_PIPE_OPERATORS" : { "message" : [ " and cannot coexist in the same SQL pipe operator using '|>'. Please separate the multiple result clauses into separate pipe operators and then retry the query again." diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index e41cdfaaf1e7..a73ec2041c01 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -807,4 +807,8 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { messageParameters = Map("constraint" -> constraint), ctx) } + + def multiplePrimaryKeysError(ctx: ParserRuleContext): Throwable = { + new ParseException(errorClass = "MULTIPLE_PRIMARY_KEYS", ctx) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7029abd8520d..07d55fb6357b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3953,11 +3953,7 @@ class AstBuilder extends DataTypeAstBuilder ctx: ColumnConstraintContext): TableConstraint = { val columns = Seq(columnName) if (ctx.uniqueSpec() != null) { - if (ctx.uniqueSpec().UNIQUE() != null) { - UniqueConstraint(columns) - } else { - PrimaryKeyConstraint(columns) - } + visitUniqueSpec(ctx.uniqueSpec(), columns) } else { assert(ctx.referenceSpec() != null) val (tableId, refColumns) = visitReferenceSpec(ctx.referenceSpec()) @@ -3968,6 +3964,15 @@ class AstBuilder extends DataTypeAstBuilder } } + private def visitUniqueSpec(ctx: UniqueSpecContext, columns: Seq[String]): TableConstraint = + withOrigin(ctx) { + if (ctx.UNIQUE() != null) { + UniqueConstraint(columns) + } else { + PrimaryKeyConstraint(columns) + } + } + override def visitReferenceSpec(ctx: ReferenceSpecContext): (Seq[String], Seq[String]) = { val tableId = visitMultipartIdentifier(ctx.multipartIdentifier()) val refColumns = visitIdentifierList(ctx.parentColumns) @@ -4736,25 +4741,32 @@ class AstBuilder extends DataTypeAstBuilder } } - override def visitTableElementList(ctx: TableElementListContext): TableElementList = { - if (ctx == null) { - return (Nil, Nil) - } - val columnDefs = new ArrayBuffer[ColumnDefinition]() - val constraints = new ArrayBuffer[TableConstraint]() + override def visitTableElementList(ctx: TableElementListContext): TableElementList = + withOrigin(ctx) { + if (ctx == null) { + return (Nil, Nil) + } + val columnDefs = new ArrayBuffer[ColumnDefinition]() + val constraints = new ArrayBuffer[TableConstraint]() - ctx.tableElement().asScala.foreach { element => - if (element.tableConstraintDefinition() != null) { - constraints += visitTableConstraintDefinition(element.tableConstraintDefinition()) - } else { - val (colDef, constraintOpt) = visitColDefinition(element.colDefinition()) - columnDefs += colDef - constraintOpt.foreach(constraints += _) + ctx.tableElement().asScala.foreach { element => + if (element.tableConstraintDefinition() != null) { + constraints += visitTableConstraintDefinition(element.tableConstraintDefinition()) + } else { + val (colDef, constraintOpt) = visitColDefinition(element.colDefinition()) + columnDefs += colDef + constraintOpt.foreach(constraints += _) + 
} } - } - (columnDefs.toSeq, constraints.toSeq) - } + // check if there are multiple primary keys + val primaryKeys = constraints.filter(_.isInstanceOf[PrimaryKeyConstraint]) + if (primaryKeys.size > 1) { + throw QueryParsingErrors.multiplePrimaryKeysError(ctx) + } + + (columnDefs.toSeq, constraints.toSeq) + } /** * Create a table, returning a [[CreateTable]] or [[CreateTableAsSelect]] logical plan. @@ -5338,6 +5350,13 @@ class AstBuilder extends DataTypeAstBuilder condition = condition) } + + override def visitUniqueConstraint(ctx: UniqueConstraintContext): TableConstraint = + withOrigin(ctx) { + val columns = visitIdentifierList(ctx.identifierList()) + visitUniqueSpec(ctx.uniqueSpec(), columns) + } + private def visitConstraintCharacteristic( ctx: TableConstraintDefinitionContext): ConstraintCharacteristic = { var enforcement: Option[String] = None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala index c9eafe69de21..5573470c7913 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala @@ -109,4 +109,98 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { val expected = createExpectedPlan(columns, Seq.empty) comparePlans(parsePlan(sql), expected) } + + + test("Create table with primary key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, PRIMARY KEY (a)) USING parquet" + val constraint = PrimaryKeyConstraint(columns = Seq("a")) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named primary key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT pk1 PRIMARY KEY (a)) USING parquet" + val constraint = PrimaryKeyConstraint( + columns = Seq("a"), + name = "pk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with composite primary key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, PRIMARY KEY (a, b)) USING parquet" + val constraint = PrimaryKeyConstraint( + columns = Seq("a", "b") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with primary key - column level") { + val sql = "CREATE TABLE t (a INT PRIMARY KEY, b STRING) USING parquet" + val constraint = PrimaryKeyConstraint( + columns = Seq("a") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with multiple primary keys should fail") { + val expectedContext = ExpectedContext( + fragment = "a INT PRIMARY KEY, b STRING, PRIMARY KEY (b)", + start = 16, + stop = 59 + ) + checkError( + exception = intercept[ParseException] { + parsePlan("CREATE TABLE t (a INT PRIMARY KEY, b STRING, PRIMARY KEY (b)) USING parquet") + }, + condition = "MULTIPLE_PRIMARY_KEYS", + parameters = Map.empty[String, String], + queryContext = Array(expectedContext)) + } + + test("Create table with unique constraint - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, UNIQUE (a)) USING parquet" + val constraint = UniqueConstraint(columns = Seq("a")) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named unique constraint - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, 
CONSTRAINT uk1 UNIQUE (a)) USING parquet" + val constraint = UniqueConstraint( + columns = Seq("a"), + name = "uk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with composite unique constraint - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, UNIQUE (a, b)) USING parquet" + val constraint = UniqueConstraint( + columns = Seq("a", "b") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with unique constraint - column level") { + val sql = "CREATE TABLE t (a INT UNIQUE, b STRING) USING parquet" + val constraint = UniqueConstraint( + columns = Seq("a") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with multiple unique constraints") { + val sql = "CREATE TABLE t (a INT UNIQUE, b STRING, UNIQUE (b)) USING parquet" + val constraint1 = UniqueConstraint(columns = Seq("a")) + val constraint2 = UniqueConstraint(columns = Seq("b")) + val constraints = Seq(constraint1, constraint2) + verifyConstraints(sql, constraints) + } } From c8edfd01adb020e8fc8570d20e31fcf48f429b0a Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 28 Mar 2025 08:30:37 -0700 Subject: [PATCH 45/65] create table with FK --- .../sql/catalyst/parser/SqlBaseParser.g4 | 6 +- .../sql/catalyst/parser/AstBuilder.scala | 44 +++++++++--- .../CreateTableConstraintParseSuite.scala | 68 +++++++++++++++++++ 3 files changed, 107 insertions(+), 11 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index c000da0b3b3d..f402e9625063 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1360,7 +1360,7 @@ colDefinitionOption | defaultExpression | generationExpression | commentSpec - | columnConstraint + | columnConstraintDefinition ; generationExpression @@ -1563,11 +1563,11 @@ uniqueConstraint ; referenceSpec - : REFERENCES multipartIdentifier (LEFT_PAREN parentColumns=identifierList RIGHT_PAREN)? + : REFERENCES multipartIdentifier (parentColumns=identifierList)? 
; foreignKeyConstraint - : FOREIGN KEY LEFT_PAREN identifierList RIGHT_PAREN referenceSpec + : FOREIGN KEY identifierList referenceSpec ; constraintCharacteristic diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 07d55fb6357b..bff5c1228a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3887,7 +3887,7 @@ class AstBuilder extends DataTypeAstBuilder var defaultExpression: Option[DefaultExpressionContext] = None var generationExpression: Option[GenerationExpressionContext] = None var commentSpec: Option[CommentSpecContext] = None - var columnConstraint: Option[ColumnConstraintContext] = None + var columnConstraint: Option[ColumnConstraintDefinitionContext] = None ctx.colDefinitionOption().asScala.foreach { option => if (option.NULL != null) { blockBang(option.errorCapturingNot) @@ -3921,12 +3921,12 @@ class AstBuilder extends DataTypeAstBuilder } commentSpec = Some(spec) } - Option(option.columnConstraint()).foreach { spec => + Option(option.columnConstraintDefinition()).foreach { definition => if (columnConstraint.isDefined) { throw QueryParsingErrors.duplicateTableColumnDescriptor( option, name, "CONSTRAINT") } - columnConstraint = Some(spec) + columnConstraint = Some(definition) } } @@ -3944,10 +3944,27 @@ class AstBuilder extends DataTypeAstBuilder case ctx: IdentityColumnContext => visitIdentityColumn(ctx, dataType) } ) - val constraint = columnConstraint.map(c => visitColumnConstraint(name, c)) + val constraint = columnConstraint.map(c => visitColumnConstraintDefinition(name, c)) (columnDef, constraint) } + private def visitColumnConstraintDefinition( + columnName: String, + ctx: ColumnConstraintDefinitionContext): TableConstraint = { + withOrigin(ctx) { + val name = if (ctx.name != null) { + ctx.name.getText + } else { + null + } + val constraintCharacteristic = + visitConstraintCharacteristics(ctx.constraintCharacteristic().asScala.toSeq) + val expr = visitColumnConstraint(columnName, ctx.columnConstraint()) + + expr.withNameAndCharacteristic(name, constraintCharacteristic) + } + } + private def visitColumnConstraint( columnName: String, ctx: ColumnConstraintContext): TableConstraint = { @@ -5335,7 +5352,8 @@ class AstBuilder extends DataTypeAstBuilder } else { null } - val constraintCharacteristic = visitConstraintCharacteristic(ctx) + val constraintCharacteristic = + visitConstraintCharacteristics(ctx.constraintCharacteristic().asScala.toSeq) val expr = visitTableConstraint(ctx.tableConstraint()).asInstanceOf[TableConstraint] @@ -5357,11 +5375,21 @@ class AstBuilder extends DataTypeAstBuilder visitUniqueSpec(ctx.uniqueSpec(), columns) } - private def visitConstraintCharacteristic( - ctx: TableConstraintDefinitionContext): ConstraintCharacteristic = { + override def visitForeignKeyConstraint(ctx: ForeignKeyConstraintContext): TableConstraint = + withOrigin(ctx) { + val columns = visitIdentifierList(ctx.identifierList()) + val (parentTableId, parentColumns) = visitReferenceSpec(ctx.referenceSpec()) + ForeignKeyConstraint( + childColumns = columns, + parentTableId = parentTableId, + parentColumns = parentColumns) + } + + private def visitConstraintCharacteristics( + constraintCharacteristics: Seq[ConstraintCharacteristicContext]): ConstraintCharacteristic = { var enforcement: Option[String] = None var rely: 
Option[String] = None - ctx.constraintCharacteristic().asScala.foreach { + constraintCharacteristics.foreach { case e if e.enforcedCharacteristic() != null => val text = getOriginalText(e.enforcedCharacteristic()).toUpperCase(Locale.ROOT) if (enforcement.isDefined) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala index 5573470c7913..4c003c1b79b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala @@ -146,6 +146,16 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { verifyConstraints(sql, constraints) } + test("Create table with named primary key - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT pk1 PRIMARY KEY, b STRING) USING parquet" + val constraint = PrimaryKeyConstraint( + columns = Seq("a"), + name = "pk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + test("Create table with multiple primary keys should fail") { val expectedContext = ExpectedContext( fragment = "a INT PRIMARY KEY, b STRING, PRIMARY KEY (b)", @@ -196,6 +206,16 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { verifyConstraints(sql, constraints) } + test("Create table with named unique constraint - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT uk1 UNIQUE, b STRING) USING parquet" + val constraint = UniqueConstraint( + columns = Seq("a"), + name = "uk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + test("Create table with multiple unique constraints") { val sql = "CREATE TABLE t (a INT UNIQUE, b STRING, UNIQUE (b)) USING parquet" val constraint1 = UniqueConstraint(columns = Seq("a")) @@ -203,4 +223,52 @@ class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { val constraints = Seq(constraint1, constraint2) verifyConstraints(sql, constraints) } + + test("Create table with foreign key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING," + + " FOREIGN KEY (a) REFERENCES parent(id)) USING parquet" + val constraint = ForeignKeyConstraint( + childColumns = Seq("a"), + parentTableId = Seq("parent"), + parentColumns = Seq("id") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named foreign key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT fk1 FOREIGN KEY (a)" + + " REFERENCES parent(id)) USING parquet" + val constraint = ForeignKeyConstraint( + childColumns = Seq("a"), + parentTableId = Seq("parent"), + parentColumns = Seq("id"), + name = "fk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with foreign key - column level") { + val sql = "CREATE TABLE t (a INT REFERENCES parent(id), b STRING) USING parquet" + val constraint = ForeignKeyConstraint( + childColumns = Seq("a"), + parentTableId = Seq("parent"), + parentColumns = Seq("id") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named foreign key - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT fk1 REFERENCES parent(id), b STRING) USING parquet" + val constraint = ForeignKeyConstraint( + childColumns = Seq("a"), + parentTableId = 
Seq("parent"), + parentColumns = Seq("id"), + name = "fk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } } From 5317f28d5793a9017490225506ef8871b790513d Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 28 Mar 2025 08:46:35 -0700 Subject: [PATCH 46/65] refactor alter --- .../sql/catalyst/analysis/CheckAnalysis.scala | 10 +-- .../catalyst/expressions/constraints.scala | 1 + .../sql/catalyst/parser/AstBuilder.scala | 6 +- .../plans/logical/v2AlterTableCommands.scala | 8 +- .../AlterTableAddConstraintParseSuite.scala | 90 ++++++++++++++++++- 5 files changed, 98 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d48ea55a45a5..fb0721a0db6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1191,16 +1191,16 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString case _ => } - case addConstraint @ AddCheckConstraint(table: ResolvedTable, constraintExpr) => - if (!constraintExpr.resolved) { - constraintExpr.child.failAnalysis( + case AddConstraint(table: ResolvedTable, check: CheckConstraint) => + if (!check.resolved) { + check.child.failAnalysis( errorClass = "INVALID_CHECK_CONSTRAINT.UNRESOLVED", messageParameters = Map.empty ) } - if (!constraintExpr.deterministic) { - constraintExpr.child.failAnalysis( + if (!check.deterministic) { + check.child.failAnalysis( errorClass = "INVALID_CHECK_CONSTRAINT.NONDETERMINISTIC", messageParameters = Map.empty ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index fe68a7166746..237e67983cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.types.{DataType, StringType} trait TableConstraint { + // Convert to a data source v2 constraint def asConstraint: Constraint def withNameAndCharacteristic( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bff5c1228a7a..c3bf27ca165a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5425,10 +5425,8 @@ class AstBuilder extends DataTypeAstBuilder withOrigin(ctx) { val table = createUnresolvedTable( ctx.identifierReference, "ALTER TABLE ... 
ADD CONSTRAINT") - visitTableConstraintDefinition(ctx.tableConstraintDefinition()) match { - case c: CheckConstraint => - AddCheckConstraint(table, c) - } + val tableConstraint = visitTableConstraintDefinition(ctx.tableConstraintDefinition()) + AddConstraint(table, tableConstraint) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala index 4a34c1654876..237fd8e7ef3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, ResolvedFieldName, UnresolvedException} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ClusterBySpec -import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, Expression, Unevaluable} +import org.apache.spark.sql.catalyst.expressions.{Expression, TableConstraint, Unevaluable} import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, TypeUtils} import org.apache.spark.sql.connector.catalog.{TableCatalog, TableChange} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -292,11 +292,11 @@ case class AlterTableCollation( /** * The logical plan of the ALTER TABLE ... ADD CONSTRAINT command. */ -case class AddCheckConstraint( +case class AddConstraint( table: LogicalPlan, - check: CheckConstraint) extends AlterTableCommand { + tableConstraint: TableConstraint) extends AlterTableCommand { override def changes: Seq[TableChange] = { - val constraint = check.asConstraint + val constraint = tableConstraint.asConstraint Seq(TableChange.addConstraint(constraint, constraint.enforced())) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala index ce6e86842346..7fa0cafe1b26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedTable} -import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, GreaterThan, Literal} +import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, GreaterThan, Literal, PrimaryKeyConstraint} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.AddCheckConstraint +import org.apache.spark.sql.catalyst.plans.logical.AddConstraint class AlterTableAddConstraintParseSuite extends ConstraintParseSuiteBase { @@ -30,7 +30,7 @@ class AlterTableAddConstraintParseSuite extends ConstraintParseSuiteBase { |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) |""".stripMargin val parsed = parsePlan(sql) - val expected = AddCheckConstraint( + val expected = AddConstraint( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... 
ADD CONSTRAINT"), @@ -71,7 +71,7 @@ class AlterTableAddConstraintParseSuite extends ConstraintParseSuiteBase { |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) $enforcedStr $relyStr |""".stripMargin val parsed = parsePlan(sql) - val expected = AddCheckConstraint( + val expected = AddConstraint( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... ADD CONSTRAINT"), @@ -117,4 +117,86 @@ class AlterTableAddConstraintParseSuite extends ConstraintParseSuiteBase { queryContext = Array(expectedContext)) } } + + test("Add primary key constraint") { + Seq(("", null), ("CONSTRAINT pk1", "pk1")).foreach { case (constraintName, expectedName) => + val sql = + s""" + |ALTER TABLE a.b.c ADD $constraintName PRIMARY KEY (id, name) + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... ADD CONSTRAINT"), + PrimaryKeyConstraint( + name = expectedName, + columns = Seq("id", "name") + )) + comparePlans(parsed, expected) + } + } + + test("Add invalid primary key constraint name") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT pk-1 PRIMARY KEY (id) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "'-'", "hint" -> "")) + } + + test("Add primary key constraint with empty columns") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY () + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "')'", "hint" -> "")) + } + + test("Add primary key constraint with valid characteristic") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = + s""" + |ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY (id) $enforcedStr $relyStr + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... 
ADD CONSTRAINT"), + PrimaryKeyConstraint( + name = "pk1", + columns = Seq("id"), + characteristic = characteristic + )) + comparePlans(parsed, expected) + } + } + + test("Add primary key constraint with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY (id) $characteristic1 $characteristic2" + + val e = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT pk1 PRIMARY KEY (id) $characteristic1 $characteristic2", + start = 22, + stop = 54 + characteristic1.length + characteristic2.length + ) + checkError( + exception = e, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } } From 9ada207ed63620cf4e81dec27557b55c8b868d22 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 28 Mar 2025 10:18:00 -0700 Subject: [PATCH 47/65] add test case for unique --- .../AlterTableAddConstraintParseSuite.scala | 84 ++++++++++++++++++- 1 file changed, 83 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala index 7fa0cafe1b26..a0459b5ec990 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedTable} -import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, GreaterThan, Literal, PrimaryKeyConstraint} +import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, GreaterThan, Literal, PrimaryKeyConstraint, UniqueConstraint} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.AddConstraint @@ -199,4 +199,86 @@ class AlterTableAddConstraintParseSuite extends ConstraintParseSuiteBase { queryContext = Array(expectedContext)) } } + + test("Add unique constraint") { + Seq(("", null), ("CONSTRAINT uk1", "uk1")).foreach { case (constraintName, expectedName) => + val sql = + s""" + |ALTER TABLE a.b.c ADD $constraintName UNIQUE (email, username) + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... 
ADD CONSTRAINT"), + UniqueConstraint( + name = expectedName, + columns = Seq("email", "username") + )) + comparePlans(parsed, expected) + } + } + + test("Add invalid unique constraint name") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT uk-1 UNIQUE (email) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "'-'", "hint" -> "")) + } + + test("Add unique constraint with empty columns") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT uk1 UNIQUE () + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "')'", "hint" -> "")) + } + + test("Add unique constraint with valid characteristic") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = + s""" + |ALTER TABLE a.b.c ADD CONSTRAINT uk1 UNIQUE (email) $enforcedStr $relyStr + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... ADD CONSTRAINT"), + UniqueConstraint( + name = "uk1", + columns = Seq("email"), + characteristic = characteristic + )) + comparePlans(parsed, expected) + } + } + + test("Add unique constraint with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT uk1 UNIQUE (email) $characteristic1 $characteristic2" + + val e = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT uk1 UNIQUE (email) $characteristic1 $characteristic2", + start = 22, + stop = 52 + characteristic1.length + characteristic2.length + ) + checkError( + exception = e, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } } From 34514521a0e8c76ca35151a95ca88c2aac0bf2aa Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 28 Mar 2025 10:55:17 -0700 Subject: [PATCH 48/65] add test case for fk --- .../sql/catalyst/parser/AstBuilder.scala | 13 +-- .../AlterTableAddConstraintParseSuite.scala | 95 ++++++++++++++++++- 2 files changed, 101 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c3bf27ca165a..ba4e2c90339d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3967,7 +3967,7 @@ class AstBuilder extends DataTypeAstBuilder private def visitColumnConstraint( columnName: String, - ctx: ColumnConstraintContext): TableConstraint = { + ctx: ColumnConstraintContext): TableConstraint = withOrigin(ctx) { val columns = Seq(columnName) if (ctx.uniqueSpec() != null) { visitUniqueSpec(ctx.uniqueSpec(), columns) @@ -3990,11 +3990,12 @@ class AstBuilder extends DataTypeAstBuilder } } - override def visitReferenceSpec(ctx: ReferenceSpecContext): (Seq[String], Seq[String]) = { - val tableId = visitMultipartIdentifier(ctx.multipartIdentifier()) - val refColumns = visitIdentifierList(ctx.parentColumns) - (tableId, refColumns) - } + override def visitReferenceSpec(ctx: ReferenceSpecContext): (Seq[String], Seq[String]) = + withOrigin(ctx) { + val tableId = 
visitMultipartIdentifier(ctx.multipartIdentifier()) + val refColumns = visitIdentifierList(ctx.parentColumns) + (tableId, refColumns) + } /** * Create a location string. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala index a0459b5ec990..5250a55d42ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedTable} -import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, GreaterThan, Literal, PrimaryKeyConstraint, UniqueConstraint} +import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, ForeignKeyConstraint, GreaterThan, Literal, PrimaryKeyConstraint, UniqueConstraint} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.AddConstraint @@ -281,4 +281,97 @@ class AlterTableAddConstraintParseSuite extends ConstraintParseSuiteBase { queryContext = Array(expectedContext)) } } +test("Add foreign key constraint") { + Seq(("", null), ("CONSTRAINT fk1", "fk1")).foreach { case (constraintName, expectedName) => + val sql = + s""" + |ALTER TABLE orders ADD $constraintName FOREIGN KEY (customer_id) + |REFERENCES customers (id) + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("orders"), + "ALTER TABLE ... ADD CONSTRAINT"), + ForeignKeyConstraint( + name = expectedName, + childColumns = Seq("customer_id"), + parentTableId = Seq("customers"), + parentColumns = Seq("id") + )) + comparePlans(parsed, expected) + } + } + + test("Add invalid foreign key constraint name") { + val sql = + """ + |ALTER TABLE orders ADD CONSTRAINT fk-1 FOREIGN KEY (customer_id) + |REFERENCES customers (id) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "'-'", "hint" -> "")) + } + + test("Add foreign key constraint with empty columns") { + val sql = + """ + |ALTER TABLE orders ADD CONSTRAINT fk1 FOREIGN KEY () + |REFERENCES customers (id) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "')'", "hint" -> "")) + } + + test("Add foreign key constraint with valid characteristic") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = + s""" + |ALTER TABLE orders ADD CONSTRAINT fk1 FOREIGN KEY (customer_id) + |REFERENCES customers (id) $enforcedStr $relyStr + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("orders"), + "ALTER TABLE ... 
ADD CONSTRAINT"), + ForeignKeyConstraint( + name = "fk1", + childColumns = Seq("customer_id"), + parentTableId = Seq("customers"), + parentColumns = Seq("id"), + characteristic = characteristic + )) + comparePlans(parsed, expected) + } + } + + test("Add foreign key constraint with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val sql = + s""" + |ALTER TABLE orders ADD CONSTRAINT fk1 FOREIGN KEY (customer_id) + |REFERENCES customers (id) $characteristic1 $characteristic2 + |""".stripMargin + + val e = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT fk1 FOREIGN KEY (customer_id)\nREFERENCES customers (id) " + + s"$characteristic1 $characteristic2", + start = 24, + stop = 91 + characteristic1.length + characteristic2.length + ) + checkError( + exception = e, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } } From 657fce1e7204fa9767b437b7d452d0c5598e967f Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 28 Mar 2025 11:19:04 -0700 Subject: [PATCH 49/65] remove legacy CreateTableConstraintSuite.scala --- .../v2/CreateTableConstraintSuite.scala | 86 ------------------- 1 file changed, 86 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala deleted file mode 100644 index bd78eb4d3c46..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableConstraintSuite.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.command.v2 - -import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.connector.catalog.constraints.Check -import org.apache.spark.sql.execution.command.DDLCommandTestUtils - -class CreateTableConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { - override protected def command: String = "CREATE TABLE .. 
CONSTRAINT" - - test("Create table with one check constraint") { - withNamespaceAndTable("ns", "tbl", catalog) { t => - sql( - s""" - |CREATE TABLE $t (id bigint, data string) $defaultUsing - | CONSTRAINT c1 CHECK (id > 0)""".stripMargin) - val constraints = loadTable(catalog, "ns", "tbl").constraints - assert(constraints.length == 1) - assert(constraints.head.isInstanceOf[Check]) - val constraint = constraints.head.asInstanceOf[Check] - - assert(constraint.name == "c1") - assert(constraint.sql == "id>0") - assert(constraint.predicate().toString() == "id > CAST(0 AS long)") - } - } - - test("Create table with two check constraints") { - withNamespaceAndTable("ns", "tbl", catalog) { t => - sql( - s""" - |CREATE TABLE $t (id bigint, data string) $defaultUsing - | CONSTRAINT c1 CHECK (id > 0) - | CONSTRAINT c2 CHECK (data = 'foo')""".stripMargin) - val constraints = loadTable(catalog, "ns", "tbl").constraints - assert(constraints.length == 2) - assert(constraints.head.isInstanceOf[Check]) - val constraint = constraints.head.asInstanceOf[Check] - - assert(constraint.name == "c1") - assert(constraint.sql == "id>0") - assert(constraint.predicate().toString() == "id > CAST(0 AS long)") - - assert(constraints(1).isInstanceOf[Check]) - val constraint2 = constraints(1).asInstanceOf[Check] - - assert(constraint2.name == "c2") - assert(constraint2.sql == "data='foo'") - assert(constraint2.predicate().toString() == "data = 'foo'") - } - } - - test("Create table with UnresolvedAttribute in check constraint") { - withNamespaceAndTable("ns", "tbl", catalog) { t => - val query = - s""" - |CREATE TABLE $t (id bigint, data string) $defaultUsing - | CONSTRAINT c2 CHECK (abc = 'foo')""".stripMargin - val e = intercept[AnalysisException] { - sql(query) - } - checkError( - exception = e, - condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`abc`", "proposal" -> "`id`, `data`"), - sqlState = "42703", - context = ExpectedContext("abc", 89, 91)) // UnresolvedAttribute abc - } - } -} From 1b5cc2f9580da142d0b625be9b3a739b964098a7 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 28 Mar 2025 15:29:27 -0700 Subject: [PATCH 50/65] change default value of rely; add PrimaryKeyConstraintSuite --- .../catalyst/expressions/constraints.scala | 6 +- .../v2/PrimaryKeyConstraintSuite.scala | 103 ++++++++++++++++++ 2 files changed, 106 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/PrimaryKeyConstraintSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index 237e67983cb5..4fecb1938b21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -122,7 +122,7 @@ case class PrimaryKeyConstraint( override def defaultName: String = "pk" override def defaultConstraintCharacteristic: ConstraintCharacteristic = - ConstraintCharacteristic(enforced = Some(false), rely = Some(true)) + ConstraintCharacteristic(enforced = Some(false), rely = Some(false)) } case class UniqueConstraint( @@ -154,7 +154,7 @@ case class UniqueConstraint( messageParameters = Map.empty) override def defaultConstraintCharacteristic: ConstraintCharacteristic = - ConstraintCharacteristic(enforced = Some(false), rely = Some(true)) + ConstraintCharacteristic(enforced = 
Some(false), rely = Some(false)) } case class ForeignKeyConstraint( @@ -189,5 +189,5 @@ case class ForeignKeyConstraint( override def defaultName: String = "fk" override def defaultConstraintCharacteristic: ConstraintCharacteristic = - ConstraintCharacteristic(enforced = Some(false), rely = Some(true)) + ConstraintCharacteristic(enforced = Some(false), rely = Some(false)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/PrimaryKeyConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/PrimaryKeyConstraintSuite.scala new file mode 100644 index 000000000000..731754062467 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/PrimaryKeyConstraintSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.execution.command.DDLCommandTestUtils + + +class PrimaryKeyConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { + override protected def command: String = "ALTER TABLE .. 
ADD CONSTRAINT" + + test("Add primary key constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT pk1 PRIMARY KEY (id) $characteristic") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "pk1") + assert(constraint.toDDL == s"CONSTRAINT pk1 PRIMARY KEY (id) $expectedDDL") + } + } + } + + test("Create table with primary key constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + val constraintStr = s"CONSTRAINT pk1 PRIMARY KEY (id) $characteristic" + sql(s"CREATE TABLE $t (id bigint, data string, $constraintStr) $defaultUsing") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "pk1") + assert(constraint.toDDL == s"CONSTRAINT pk1 PRIMARY KEY (id) $expectedDDL") + } + } + } + + test("Add duplicated primary key constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT pk1 PRIMARY KEY (id)") + // Constraint names are case-insensitive + Seq("pk1", "PK1").foreach { name => + val error = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ADD CONSTRAINT $name PRIMARY KEY (id)") + } + checkError( + exception = error, + condition = "CONSTRAINT_ALREADY_EXISTS", + sqlState = "42710", + parameters = Map("constraintName" -> "pk1", + "oldConstraint" -> "CONSTRAINT pk1 PRIMARY KEY (id) NOT ENFORCED UNVALIDATED NORELY") + ) + } + } + } + + test("Add primary key constraint with multiple columns") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id1 bigint, id2 bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT pk1 PRIMARY KEY (id1, id2)") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "pk1") + assert(constraint.toDDL == + "CONSTRAINT pk1 PRIMARY KEY (id1, id2) NOT ENFORCED UNVALIDATED NORELY") + } + } + + val validConstraintCharacteristics = Seq( + ("", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("ENFORCED RELY", "ENFORCED UNVALIDATED RELY"), + ("RELY ENFORCED", "ENFORCED UNVALIDATED RELY"), + ("RELY", "NOT ENFORCED UNVALIDATED RELY") + ) +} From 44bbd781dd39075d08f30eef943397f6ed7f0a1f Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 28 Mar 2025 15:49:07 -0700 Subject: [PATCH 51/65] change check constraint to NORELY --- .../spark/sql/catalyst/expressions/constraints.scala | 2 +- .../execution/command/v2/CheckConstraintSuite.scala | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index 4fecb1938b21..a9af7a67feef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -89,7 +89,7 @@ case class CheckConstraint( messageParameters = Map.empty) override def defaultConstraintCharacteristic: ConstraintCharacteristic = - ConstraintCharacteristic(enforced = Some(true), rely = Some(true)) + ConstraintCharacteristic(enforced = Some(true), rely = Some(false)) override def sql: String = s"CONSTRAINT $name CHECK ($condition)" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index b6d60912f5f5..97cc68ede91d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -89,20 +89,22 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma val constraint = getCheckConstraint(table) assert(constraint.name() == "c1") assert(constraint.toDDL == - "CONSTRAINT c1 CHECK from_json(j, 'a INT').a > 1 ENFORCED VALID RELY") + "CONSTRAINT c1 CHECK from_json(j, 'a INT').a > 1 ENFORCED VALID NORELY") assert(constraint.sql() == "from_json(j, 'a INT').a > 1") assert(constraint.predicate() == null) } } val validConstraintCharacteristics = Seq( - ("", "ENFORCED VALID RELY"), - ("NOT ENFORCED", "NOT ENFORCED VALID RELY"), + ("", "ENFORCED VALID NORELY"), + ("NOT ENFORCED", "NOT ENFORCED VALID NORELY"), ("NOT ENFORCED NORELY", "NOT ENFORCED VALID NORELY"), ("NORELY NOT ENFORCED", "NOT ENFORCED VALID NORELY"), ("NORELY", "ENFORCED VALID NORELY"), ("NOT ENFORCED RELY", "NOT ENFORCED VALID RELY"), ("RELY NOT ENFORCED", "NOT ENFORCED VALID RELY"), + ("NOT ENFORCED RELY", "NOT ENFORCED VALID RELY"), + ("RELY NOT ENFORCED", "NOT ENFORCED VALID RELY"), ("RELY", "ENFORCED VALID RELY") ) @@ -150,7 +152,7 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma condition = "CONSTRAINT_ALREADY_EXISTS", sqlState = "42710", parameters = Map("constraintName" -> "abc", - "oldConstraint" -> "CONSTRAINT abc CHECK id > 0 ENFORCED VALID RELY") + "oldConstraint" -> "CONSTRAINT abc CHECK id > 0 ENFORCED VALID NORELY") ) } } From 87d22060607555605968f43a3da42c24ab457397 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 28 Mar 2025 16:53:59 -0700 Subject: [PATCH 52/65] change the valid status of check --- .../catalyst/analysis/ResolveTableSpec.scala | 2 +- .../catalyst/expressions/constraints.scala | 19 ++++++++---- .../plans/logical/v2AlterTableCommands.scala | 2 +- .../command/v2/CheckConstraintSuite.scala | 31 ++++++++++--------- 4 files changed, 32 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala index e4ce620717b5..48bd9dbaa0be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala @@ -131,7 +131,7 @@ object ResolveTableSpec extends Rule[LogicalPlan] { collation = u.collation, serde = u.serde, external = u.external, - 
constraints = newConstraints.map(_.asConstraint)) + constraints = newConstraints.map(_.asConstraint(isCreateTable = true))) withNewSpec(newTableSpec) case _ => input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index a9af7a67feef..ff47c549ee82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types.{DataType, StringType} trait TableConstraint { // Convert to a data source v2 constraint - def asConstraint: Constraint + def asConstraint(isCreateTable: Boolean): Constraint def withNameAndCharacteristic( name: String, @@ -60,17 +60,24 @@ case class CheckConstraint( with Unevaluable with TableConstraint { - def asConstraint: Constraint = { + def asConstraint(isCreateTable: Boolean): Constraint = { val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull val (rely, enforced) = getCharacteristicValues val constraintName = if (name == null) defaultName else name + // The validation status is set to UNVALIDATED for create table and + // VALID for alter table. + val validateStatus = if (isCreateTable) { + Constraint.ValidationStatus.UNVALIDATED + } else { + Constraint.ValidationStatus.VALID + } Constraint .check(constraintName) .sql(condition) .predicate(predicate) .rely(rely) .enforced(enforced) - .validationStatus(Constraint.ValidationStatus.VALID) + .validationStatus(validateStatus) .build() } @@ -102,7 +109,7 @@ case class PrimaryKeyConstraint( override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) extends TableConstraint { - override def asConstraint: Constraint = { + override def asConstraint(isCreateTable: Boolean): Constraint = { val (rely, enforced) = getCharacteristicValues val constraintName = if (name == null) defaultName else name Constraint @@ -131,7 +138,7 @@ case class UniqueConstraint( override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) extends TableConstraint { - override def asConstraint: Constraint = { + override def asConstraint(isCreateTable: Boolean): Constraint = { val (rely, enforced) = getCharacteristicValues val constraintName = if (name == null) defaultName else name Constraint @@ -166,7 +173,7 @@ case class ForeignKeyConstraint( extends TableConstraint { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - override def asConstraint: Constraint = { + override def asConstraint(isCreateTable: Boolean): Constraint = { val (rely, enforced) = getCharacteristicValues val constraintName = if (name == null) defaultName else name Constraint diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala index 237fd8e7ef3b..845c6a64a569 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala @@ -296,7 +296,7 @@ case class AddConstraint( table: LogicalPlan, tableConstraint: TableConstraint) extends AlterTableCommand { override def changes: Seq[TableChange] = { - val constraint = tableConstraint.asConstraint + val constraint = 
tableConstraint.asConstraint(isCreateTable = false) Seq(TableChange.addConstraint(constraint, constraint.enforced())) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index 97cc68ede91d..0d35fd4c48b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -95,21 +95,24 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma } } - val validConstraintCharacteristics = Seq( - ("", "ENFORCED VALID NORELY"), - ("NOT ENFORCED", "NOT ENFORCED VALID NORELY"), - ("NOT ENFORCED NORELY", "NOT ENFORCED VALID NORELY"), - ("NORELY NOT ENFORCED", "NOT ENFORCED VALID NORELY"), - ("NORELY", "ENFORCED VALID NORELY"), - ("NOT ENFORCED RELY", "NOT ENFORCED VALID RELY"), - ("RELY NOT ENFORCED", "NOT ENFORCED VALID RELY"), - ("NOT ENFORCED RELY", "NOT ENFORCED VALID RELY"), - ("RELY NOT ENFORCED", "NOT ENFORCED VALID RELY"), - ("RELY", "ENFORCED VALID RELY") - ) + def getConstraintCharacteristics(isCreateTable: Boolean): Seq[(String, String)] = { + val validStatus = if (isCreateTable) "UNVALIDATED" else "VALID" + Seq( + ("", s"ENFORCED $validStatus NORELY"), + ("NOT ENFORCED", s"NOT ENFORCED $validStatus NORELY"), + ("NOT ENFORCED NORELY", s"NOT ENFORCED $validStatus NORELY"), + ("NORELY NOT ENFORCED", s"NOT ENFORCED $validStatus NORELY"), + ("NORELY", s"ENFORCED $validStatus NORELY"), + ("NOT ENFORCED RELY", s"NOT ENFORCED $validStatus RELY"), + ("RELY NOT ENFORCED", s"NOT ENFORCED $validStatus RELY"), + ("NOT ENFORCED RELY", s"NOT ENFORCED $validStatus RELY"), + ("RELY NOT ENFORCED", s"NOT ENFORCED $validStatus RELY"), + ("RELY", s"ENFORCED $validStatus RELY") + ) + } test("Create table with check constraint") { - validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + getConstraintCharacteristics(true).foreach { case (characteristic, expectedDDL) => withNamespaceAndTable("ns", "tbl", catalog) { t => val constraintStr = s"CONSTRAINT c1 CHECK (id > 0) $characteristic" sql(s"CREATE TABLE $t (id bigint, data string, $constraintStr) $defaultUsing") @@ -122,7 +125,7 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma } test("Alter table add check constraint") { - validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + getConstraintCharacteristics(false).foreach { case (characteristic, expectedDDL) => withNamespaceAndTable("ns", "tbl", catalog) { t => sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) From 258d10d97a6a910a1cce1a75a882c8a39f0f9f45 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 31 Mar 2025 11:52:39 -0700 Subject: [PATCH 53/65] disallow enforce in pk/fk/unique --- .../resources/error/error-conditions.json | 6 +++ .../catalyst/expressions/constraints.scala | 45 ++++++++++++++++--- .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../command/CheckConstraintParseSuite.scala | 5 +++ 4 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index b042bb134d76..64b71af92552 100644 
--- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5501,6 +5501,12 @@ }, "sqlState" : "0A000" }, + "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC": { + "message": [ + "Constraint characteristic '' is not supported for constraint type ''." + ], + "sqlState": "0A000" + }, "UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY" : { "message" : [ "Unsupported data source type for direct query on files: " diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index ff47c549ee82..d233ae6d4120 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -16,7 +16,10 @@ */ package org.apache.spark.sql.catalyst.expressions +import org.antlr.v4.runtime.ParserRuleContext + import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.expressions.FieldReference @@ -28,7 +31,8 @@ trait TableConstraint { def withNameAndCharacteristic( name: String, - c: ConstraintCharacteristic): TableConstraint + c: ConstraintCharacteristic, + ctx: ParserRuleContext): TableConstraint def name: String @@ -86,7 +90,8 @@ case class CheckConstraint( override def withNameAndCharacteristic( name: String, - c: ConstraintCharacteristic): TableConstraint = { + c: ConstraintCharacteristic, + ctx: ParserRuleContext): TableConstraint = { copy(name = name, characteristic = c) } @@ -122,7 +127,17 @@ case class PrimaryKeyConstraint( override def withNameAndCharacteristic( name: String, - c: ConstraintCharacteristic): TableConstraint = { + c: ConstraintCharacteristic, + ctx: ParserRuleContext): TableConstraint = { + if (c.enforced.contains(true)) { + throw new ParseException( + errorClass = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + messageParameters = Map( + "characteristic" -> "ENFORCED", + "constraintType" -> "PRIMARY KEY"), + ctx = ctx + ) + } copy(name = name, characteristic = c) } @@ -151,7 +166,17 @@ case class UniqueConstraint( override def withNameAndCharacteristic( name: String, - c: ConstraintCharacteristic): TableConstraint = { + c: ConstraintCharacteristic, + ctx: ParserRuleContext): TableConstraint = { + if (c.enforced.contains(true)) { + throw new ParseException( + errorClass = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + messageParameters = Map( + "characteristic" -> "ENFORCED", + "constraintType" -> "UNIQUE"), + ctx = ctx + ) + } copy(name = name, characteristic = c) } @@ -189,7 +214,17 @@ case class ForeignKeyConstraint( override def withNameAndCharacteristic( name: String, - c: ConstraintCharacteristic): TableConstraint = { + c: ConstraintCharacteristic, + ctx: ParserRuleContext): TableConstraint = { + if (c.enforced.contains(true)) { + throw new ParseException( + errorClass = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + messageParameters = Map( + "characteristic" -> "ENFORCED", + "constraintType" -> "FOREIGN KEY"), + ctx = ctx + ) + } copy(name = name, characteristic = c) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 
ba4e2c90339d..fa10d96b75d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3961,7 +3961,7 @@ class AstBuilder extends DataTypeAstBuilder visitConstraintCharacteristics(ctx.constraintCharacteristic().asScala.toSeq) val expr = visitColumnConstraint(columnName, ctx.columnConstraint()) - expr.withNameAndCharacteristic(name, constraintCharacteristic) + expr.withNameAndCharacteristic(name, constraintCharacteristic, ctx) } } @@ -5358,7 +5358,7 @@ class AstBuilder extends DataTypeAstBuilder val expr = visitTableConstraint(ctx.tableConstraint()).asInstanceOf[TableConstraint] - expr.withNameAndCharacteristic(name, constraintCharacteristic) + expr.withNameAndCharacteristic(name, constraintCharacteristic, ctx) } override def visitCheckConstraint(ctx: CheckConstraintContext): CheckConstraint = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala new file mode 100644 index 000000000000..eebc6f84ea21 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala @@ -0,0 +1,5 @@ +package org.apache.spark.sql.execution.command + +class CheckConstraintParseSuite { + +} From 05bff35caeaad4e1513be043153046a8b38c6a07 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 31 Mar 2025 13:50:04 -0700 Subject: [PATCH 54/65] refactor tests --- .../AlterTableAddConstraintParseSuite.scala | 377 ------------------ .../command/CheckConstraintParseSuite.scala | 174 +++++++- .../command/ConstraintParseSuiteBase.scala | 38 +- .../CreateTableConstraintParseSuite.scala | 274 ------------- .../ForeignKeyConstraintParseSuite.scala | 244 ++++++++++++ .../PrimaryKeyConstraintParseSuite.scala | 244 ++++++++++++ .../command/UniqueConstraintParseSuite.scala | 236 +++++++++++ .../v2/PrimaryKeyConstraintSuite.scala | 21 +- 8 files changed, 938 insertions(+), 670 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/ForeignKeyConstraintParseSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/PrimaryKeyConstraintParseSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/UniqueConstraintParseSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala deleted file mode 100644 index 5250a55d42ba..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddConstraintParseSuite.scala +++ /dev/null @@ -1,377 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.command - -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedTable} -import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, ForeignKeyConstraint, GreaterThan, Literal, PrimaryKeyConstraint, UniqueConstraint} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan -import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.AddConstraint - -class AlterTableAddConstraintParseSuite extends ConstraintParseSuiteBase { - - test("Add check constraint") { - val sql = - """ - |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) - |""".stripMargin - val parsed = parsePlan(sql) - val expected = AddConstraint( - UnresolvedTable( - Seq("a", "b", "c"), - "ALTER TABLE ... ADD CONSTRAINT"), - CheckConstraint( - child = GreaterThan(UnresolvedAttribute("d"), Literal(0)), - condition = "d > 0", - name = "c1" - )) - comparePlans(parsed, expected) - } - - test("Add invalid check constraint name") { - val sql = - """ - |ALTER TABLE a.b.c ADD CONSTRAINT c1-c3 CHECK (d > 0) - |""".stripMargin - val e = intercept[ParseException] { - parsePlan(sql) - } - checkError(e, "INVALID_IDENTIFIER", "42602", Map("ident" -> "c1-c3")) - } - - test("Add invalid check constraint expression") { - val sql = - """ - |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d >) - |""".stripMargin - val msg = intercept[ParseException] { - parsePlan(sql) - }.getMessage - assert(msg.contains("Syntax error at or near ')'")) - } - - test("Add check constraint with valid characteristic") { - validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => - val sql = - s""" - |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) $enforcedStr $relyStr - |""".stripMargin - val parsed = parsePlan(sql) - val expected = AddConstraint( - UnresolvedTable( - Seq("a", "b", "c"), - "ALTER TABLE ... 
ADD CONSTRAINT"), - CheckConstraint( - child = GreaterThan(UnresolvedAttribute("d"), Literal(0)), - condition = "d > 0", - name = "c1", - characteristic = characteristic - )) - comparePlans(parsed, expected) - } - } - - - test("Add check constraint with invalid characteristic") { - val combinations = Seq( - ("ENFORCED", "ENFORCED"), - ("ENFORCED", "NOT ENFORCED"), - ("NOT ENFORCED", "ENFORCED"), - ("NOT ENFORCED", "NOT ENFORCED"), - ("RELY", "RELY"), - ("RELY", "NORELY"), - ("NORELY", "RELY"), - ("NORELY", "NORELY") - ) - - combinations.foreach { case (characteristic1, characteristic2) => - val sql = - s"ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) $characteristic1 $characteristic2" - - val e = intercept[ParseException] { - parsePlan(sql) - } - val expectedContext = ExpectedContext( - fragment = s"CONSTRAINT c1 CHECK (d > 0) $characteristic1 $characteristic2", - start = 22, - stop = 50 + characteristic1.length + characteristic2.length - ) - checkError( - exception = e, - condition = "INVALID_CONSTRAINT_CHARACTERISTICS", - parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), - queryContext = Array(expectedContext)) - } - } - - test("Add primary key constraint") { - Seq(("", null), ("CONSTRAINT pk1", "pk1")).foreach { case (constraintName, expectedName) => - val sql = - s""" - |ALTER TABLE a.b.c ADD $constraintName PRIMARY KEY (id, name) - |""".stripMargin - val parsed = parsePlan(sql) - val expected = AddConstraint( - UnresolvedTable( - Seq("a", "b", "c"), - "ALTER TABLE ... ADD CONSTRAINT"), - PrimaryKeyConstraint( - name = expectedName, - columns = Seq("id", "name") - )) - comparePlans(parsed, expected) - } - } - - test("Add invalid primary key constraint name") { - val sql = - """ - |ALTER TABLE a.b.c ADD CONSTRAINT pk-1 PRIMARY KEY (id) - |""".stripMargin - val e = intercept[ParseException] { - parsePlan(sql) - } - checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "'-'", "hint" -> "")) - } - - test("Add primary key constraint with empty columns") { - val sql = - """ - |ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY () - |""".stripMargin - val e = intercept[ParseException] { - parsePlan(sql) - } - checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "')'", "hint" -> "")) - } - - test("Add primary key constraint with valid characteristic") { - validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => - val sql = - s""" - |ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY (id) $enforcedStr $relyStr - |""".stripMargin - val parsed = parsePlan(sql) - val expected = AddConstraint( - UnresolvedTable( - Seq("a", "b", "c"), - "ALTER TABLE ... 
ADD CONSTRAINT"), - PrimaryKeyConstraint( - name = "pk1", - columns = Seq("id"), - characteristic = characteristic - )) - comparePlans(parsed, expected) - } - } - - test("Add primary key constraint with invalid characteristic") { - invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => - val sql = - s"ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY (id) $characteristic1 $characteristic2" - - val e = intercept[ParseException] { - parsePlan(sql) - } - val expectedContext = ExpectedContext( - fragment = s"CONSTRAINT pk1 PRIMARY KEY (id) $characteristic1 $characteristic2", - start = 22, - stop = 54 + characteristic1.length + characteristic2.length - ) - checkError( - exception = e, - condition = "INVALID_CONSTRAINT_CHARACTERISTICS", - parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), - queryContext = Array(expectedContext)) - } - } - - test("Add unique constraint") { - Seq(("", null), ("CONSTRAINT uk1", "uk1")).foreach { case (constraintName, expectedName) => - val sql = - s""" - |ALTER TABLE a.b.c ADD $constraintName UNIQUE (email, username) - |""".stripMargin - val parsed = parsePlan(sql) - val expected = AddConstraint( - UnresolvedTable( - Seq("a", "b", "c"), - "ALTER TABLE ... ADD CONSTRAINT"), - UniqueConstraint( - name = expectedName, - columns = Seq("email", "username") - )) - comparePlans(parsed, expected) - } - } - - test("Add invalid unique constraint name") { - val sql = - """ - |ALTER TABLE a.b.c ADD CONSTRAINT uk-1 UNIQUE (email) - |""".stripMargin - val e = intercept[ParseException] { - parsePlan(sql) - } - checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "'-'", "hint" -> "")) - } - - test("Add unique constraint with empty columns") { - val sql = - """ - |ALTER TABLE a.b.c ADD CONSTRAINT uk1 UNIQUE () - |""".stripMargin - val e = intercept[ParseException] { - parsePlan(sql) - } - checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "')'", "hint" -> "")) - } - - test("Add unique constraint with valid characteristic") { - validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => - val sql = - s""" - |ALTER TABLE a.b.c ADD CONSTRAINT uk1 UNIQUE (email) $enforcedStr $relyStr - |""".stripMargin - val parsed = parsePlan(sql) - val expected = AddConstraint( - UnresolvedTable( - Seq("a", "b", "c"), - "ALTER TABLE ... 
ADD CONSTRAINT"), - UniqueConstraint( - name = "uk1", - columns = Seq("email"), - characteristic = characteristic - )) - comparePlans(parsed, expected) - } - } - - test("Add unique constraint with invalid characteristic") { - invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => - val sql = - s"ALTER TABLE a.b.c ADD CONSTRAINT uk1 UNIQUE (email) $characteristic1 $characteristic2" - - val e = intercept[ParseException] { - parsePlan(sql) - } - val expectedContext = ExpectedContext( - fragment = s"CONSTRAINT uk1 UNIQUE (email) $characteristic1 $characteristic2", - start = 22, - stop = 52 + characteristic1.length + characteristic2.length - ) - checkError( - exception = e, - condition = "INVALID_CONSTRAINT_CHARACTERISTICS", - parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), - queryContext = Array(expectedContext)) - } - } -test("Add foreign key constraint") { - Seq(("", null), ("CONSTRAINT fk1", "fk1")).foreach { case (constraintName, expectedName) => - val sql = - s""" - |ALTER TABLE orders ADD $constraintName FOREIGN KEY (customer_id) - |REFERENCES customers (id) - |""".stripMargin - val parsed = parsePlan(sql) - val expected = AddConstraint( - UnresolvedTable( - Seq("orders"), - "ALTER TABLE ... ADD CONSTRAINT"), - ForeignKeyConstraint( - name = expectedName, - childColumns = Seq("customer_id"), - parentTableId = Seq("customers"), - parentColumns = Seq("id") - )) - comparePlans(parsed, expected) - } - } - - test("Add invalid foreign key constraint name") { - val sql = - """ - |ALTER TABLE orders ADD CONSTRAINT fk-1 FOREIGN KEY (customer_id) - |REFERENCES customers (id) - |""".stripMargin - val e = intercept[ParseException] { - parsePlan(sql) - } - checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "'-'", "hint" -> "")) - } - - test("Add foreign key constraint with empty columns") { - val sql = - """ - |ALTER TABLE orders ADD CONSTRAINT fk1 FOREIGN KEY () - |REFERENCES customers (id) - |""".stripMargin - val e = intercept[ParseException] { - parsePlan(sql) - } - checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "')'", "hint" -> "")) - } - - test("Add foreign key constraint with valid characteristic") { - validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => - val sql = - s""" - |ALTER TABLE orders ADD CONSTRAINT fk1 FOREIGN KEY (customer_id) - |REFERENCES customers (id) $enforcedStr $relyStr - |""".stripMargin - val parsed = parsePlan(sql) - val expected = AddConstraint( - UnresolvedTable( - Seq("orders"), - "ALTER TABLE ... 
ADD CONSTRAINT"), - ForeignKeyConstraint( - name = "fk1", - childColumns = Seq("customer_id"), - parentTableId = Seq("customers"), - parentColumns = Seq("id"), - characteristic = characteristic - )) - comparePlans(parsed, expected) - } - } - - test("Add foreign key constraint with invalid characteristic") { - invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => - val sql = - s""" - |ALTER TABLE orders ADD CONSTRAINT fk1 FOREIGN KEY (customer_id) - |REFERENCES customers (id) $characteristic1 $characteristic2 - |""".stripMargin - - val e = intercept[ParseException] { - parsePlan(sql) - } - val expectedContext = ExpectedContext( - fragment = s"CONSTRAINT fk1 FOREIGN KEY (customer_id)\nREFERENCES customers (id) " + - s"$characteristic1 $characteristic2", - start = 24, - stop = 91 + characteristic1.length + characteristic2.length - ) - checkError( - exception = e, - condition = "INVALID_CONSTRAINT_CHARACTERISTICS", - parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), - queryContext = Array(expectedContext)) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala index eebc6f84ea21..58028ccb115e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala @@ -1,5 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.sql.execution.command -class CheckConstraintParseSuite { +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedTable} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.{AddConstraint, ColumnDefinition} +import org.apache.spark.sql.types.StringType + +class CheckConstraintParseSuite extends ConstraintParseSuiteBase { + override val validConstraintCharacteristics = + super.validConstraintCharacteristics ++ super.enforcedConstraintCharacteristics + + test("Create table with one check constraint - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0)) USING parquet" + val constraint = CheckConstraint( + child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), + condition = "a > 0", + name = "c1") + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with two check constraints - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0), " + + "CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" + val constraint1 = CheckConstraint( + child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), + condition = "a > 0", + name = "c1") + val constraint2 = CheckConstraint( + child = EqualTo(UnresolvedAttribute("b"), Literal("foo")), + condition = "b = 'foo'", + name = "c2") + val constraints = Seq(constraint1, constraint2) + verifyConstraints(sql, constraints) + } + + test("Create table with valid characteristic - table level") { + validConstraintCharacteristics.foreach { + case (enforcedStr, relyStr, characteristic) => + val sql = s"CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0) " + + s"$enforcedStr $relyStr) USING parquet" + val constraint = CheckConstraint( + child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), + condition = "a > 0", + name = "c1", + characteristic = characteristic) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + } + + test("Create table with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2" + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2", + start = 33, + stop = 61 + characteristic1.length + characteristic2.length + ) + checkError( + exception = intercept[ParseException] { + parsePlan(s"CREATE TABLE t (a INT, b STRING, $constraintStr ) USING parquet") + }, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + + test("Create table with column 'constraint'") { + val sql = "CREATE TABLE t (constraint STRING) USING parquet" + val columns = Seq(ColumnDefinition("constraint", StringType)) + val expected = createExpectedPlan(columns, Seq.empty) + comparePlans(parsePlan(sql), expected) + } + + test("Add check constraint") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... 
ADD CONSTRAINT"), + CheckConstraint( + child = GreaterThan(UnresolvedAttribute("d"), Literal(0)), + condition = "d > 0", + name = "c1" + )) + comparePlans(parsed, expected) + } + + test("Add invalid check constraint name") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT c1-c3 CHECK (d > 0) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "INVALID_IDENTIFIER", "42602", Map("ident" -> "c1-c3")) + } + + test("Add invalid check constraint expression") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d >) + |""".stripMargin + val msg = intercept[ParseException] { + parsePlan(sql) + }.getMessage + assert(msg.contains("Syntax error at or near ')'")) + } + + test("Add check constraint with valid characteristic") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = + s""" + |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) $enforcedStr $relyStr + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... ADD CONSTRAINT"), + CheckConstraint( + child = GreaterThan(UnresolvedAttribute("d"), Literal(0)), + condition = "d > 0", + name = "c1", + characteristic = characteristic + )) + comparePlans(parsed, expected) + } + } + + test("Add check constraint with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) $characteristic1 $characteristic2" + val e = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT c1 CHECK (d > 0) $characteristic1 $characteristic2", + start = 22, + stop = 50 + characteristic1.length + characteristic2.length + ) + checkError( + exception = e, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala index c870399453af..e876207138c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala @@ -16,25 +16,33 @@ */ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.catalyst.analysis.AnalysisTest -import org.apache.spark.sql.catalyst.expressions.ConstraintCharacteristic +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedIdentifier} +import org.apache.spark.sql.catalyst.expressions.{ConstraintCharacteristic, TableConstraint} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, OptionList, UnresolvedTableSpec} import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType} abstract class ConstraintParseSuiteBase extends AnalysisTest with SharedSparkSession { - protected val validConstraintCharacteristics = Seq( + protected def validConstraintCharacteristics = Seq( ("", "", ConstraintCharacteristic(enforced = None, rely = None)), - ("ENFORCED", "", ConstraintCharacteristic(enforced = Some(true), rely = None)), ("NOT ENFORCED", "", 
ConstraintCharacteristic(enforced = Some(false), rely = None)), ("", "RELY", ConstraintCharacteristic(enforced = None, rely = Some(true))), ("", "NORELY", ConstraintCharacteristic(enforced = None, rely = Some(false))), - ("ENFORCED", "RELY", ConstraintCharacteristic(enforced = Some(true), rely = Some(true))), - ("ENFORCED", "NORELY", ConstraintCharacteristic(enforced = Some(true), rely = Some(false))), ("NOT ENFORCED", "RELY", ConstraintCharacteristic(enforced = Some(false), rely = Some(true))), ("NOT ENFORCED", "NORELY", ConstraintCharacteristic(enforced = Some(false), rely = Some(false))) ) + protected def enforcedConstraintCharacteristics = Seq( + ("ENFORCED", "", ConstraintCharacteristic(enforced = Some(true), rely = None)), + ("ENFORCED", "RELY", ConstraintCharacteristic(enforced = Some(true), rely = Some(true))), + ("ENFORCED", "NORELY", ConstraintCharacteristic(enforced = Some(true), rely = Some(false))), + ("RELY", "ENFORCED", ConstraintCharacteristic(enforced = Some(true), rely = Some(true))), + ("NORELY", "ENFORCED", ConstraintCharacteristic(enforced = Some(true), rely = Some(false))) + ) + protected val invalidConstraintCharacteristics = Seq( ("ENFORCED", "ENFORCED"), ("ENFORCED", "NOT ENFORCED"), @@ -46,5 +54,23 @@ abstract class ConstraintParseSuiteBase extends AnalysisTest with SharedSparkSes ("NORELY", "NORELY") ) + protected def createExpectedPlan( + columns: Seq[ColumnDefinition], + constraints: Seq[TableConstraint]): CreateTable = { + val tableId = UnresolvedIdentifier(Seq("t")) + val tableSpec = UnresolvedTableSpec( + Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), + None, None, None, None, false, constraints) + CreateTable(tableId, columns, Seq.empty, tableSpec, false) + } + protected def verifyConstraints(sql: String, constraints: Seq[TableConstraint]): Unit = { + val parsed = parsePlan(sql) + val columns = Seq( + ColumnDefinition("a", IntegerType), + ColumnDefinition("b", StringType) + ) + val expected = createExpectedPlan(columns = columns, constraints = constraints) + comparePlans(parsed, expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala deleted file mode 100644 index 4c003c1b79b2..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableConstraintParseSuite.scala +++ /dev/null @@ -1,274 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.command - -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan -import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, OptionList, UnresolvedTableSpec} -import org.apache.spark.sql.types.{IntegerType, StringType} - -class CreateTableConstraintParseSuite extends ConstraintParseSuiteBase { - - def createExpectedPlan( - columns: Seq[ColumnDefinition], - constraints: Seq[TableConstraint]): CreateTable = { - val tableId = UnresolvedIdentifier(Seq("t")) - val tableSpec = UnresolvedTableSpec( - Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), - None, None, None, None, false, constraints) - CreateTable(tableId, columns, Seq.empty, tableSpec, false) - } - - def verifyConstraints(sql: String, constraints: Seq[TableConstraint]): Unit = { - val parsed = parsePlan(sql) - val columns = Seq( - ColumnDefinition("a", IntegerType), - ColumnDefinition("b", StringType) - ) - val expected = createExpectedPlan(columns = columns, constraints = constraints) - comparePlans(parsed, expected) - } - - test("Create table with one check constraint - table level") { - val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0)) USING parquet" - val constraint = CheckConstraint( - child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), - condition = "a > 0", - name = "c1") - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with two check constraints - table level") { - val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0), " + - "CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" - val constraint1 = CheckConstraint( - child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), - condition = "a > 0", - name = "c1") - val constraint2 = CheckConstraint( - child = EqualTo(UnresolvedAttribute("b"), Literal("foo")), - condition = "b = 'foo'", - name = "c2") - val constraints = Seq(constraint1, constraint2) - verifyConstraints(sql, constraints) - } - - test("Create table with valid characteristic - table level") { - validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => - val sql = s"CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0) " + - s"$enforcedStr $relyStr) USING parquet" - val constraint = CheckConstraint( - child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), - condition = "a > 0", - name = "c1", - characteristic = characteristic) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - } - - test("Create table with invalid characteristic") { - invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => - val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2" - val expectedContext = ExpectedContext( - fragment = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2", - start = 33, - stop = 61 + characteristic1.length + characteristic2.length - ) - checkError( - exception = intercept[ParseException] { - parsePlan(s"CREATE TABLE t (a INT, b STRING, $constraintStr ) USING parquet") - }, - condition = "INVALID_CONSTRAINT_CHARACTERISTICS", - parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), - queryContext = Array(expectedContext)) - } - } - - test("Create table with column 
'constraint'") { - val sql = "CREATE TABLE t (constraint STRING) USING parquet" - val columns = Seq(ColumnDefinition("constraint", StringType)) - val expected = createExpectedPlan(columns, Seq.empty) - comparePlans(parsePlan(sql), expected) - } - - - test("Create table with primary key - table level") { - val sql = "CREATE TABLE t (a INT, b STRING, PRIMARY KEY (a)) USING parquet" - val constraint = PrimaryKeyConstraint(columns = Seq("a")) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with named primary key - table level") { - val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT pk1 PRIMARY KEY (a)) USING parquet" - val constraint = PrimaryKeyConstraint( - columns = Seq("a"), - name = "pk1" - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with composite primary key - table level") { - val sql = "CREATE TABLE t (a INT, b STRING, PRIMARY KEY (a, b)) USING parquet" - val constraint = PrimaryKeyConstraint( - columns = Seq("a", "b") - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with primary key - column level") { - val sql = "CREATE TABLE t (a INT PRIMARY KEY, b STRING) USING parquet" - val constraint = PrimaryKeyConstraint( - columns = Seq("a") - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with named primary key - column level") { - val sql = "CREATE TABLE t (a INT CONSTRAINT pk1 PRIMARY KEY, b STRING) USING parquet" - val constraint = PrimaryKeyConstraint( - columns = Seq("a"), - name = "pk1" - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with multiple primary keys should fail") { - val expectedContext = ExpectedContext( - fragment = "a INT PRIMARY KEY, b STRING, PRIMARY KEY (b)", - start = 16, - stop = 59 - ) - checkError( - exception = intercept[ParseException] { - parsePlan("CREATE TABLE t (a INT PRIMARY KEY, b STRING, PRIMARY KEY (b)) USING parquet") - }, - condition = "MULTIPLE_PRIMARY_KEYS", - parameters = Map.empty[String, String], - queryContext = Array(expectedContext)) - } - - test("Create table with unique constraint - table level") { - val sql = "CREATE TABLE t (a INT, b STRING, UNIQUE (a)) USING parquet" - val constraint = UniqueConstraint(columns = Seq("a")) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with named unique constraint - table level") { - val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT uk1 UNIQUE (a)) USING parquet" - val constraint = UniqueConstraint( - columns = Seq("a"), - name = "uk1" - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with composite unique constraint - table level") { - val sql = "CREATE TABLE t (a INT, b STRING, UNIQUE (a, b)) USING parquet" - val constraint = UniqueConstraint( - columns = Seq("a", "b") - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with unique constraint - column level") { - val sql = "CREATE TABLE t (a INT UNIQUE, b STRING) USING parquet" - val constraint = UniqueConstraint( - columns = Seq("a") - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with named unique constraint - column level") { - val sql = "CREATE TABLE t (a INT CONSTRAINT uk1 UNIQUE, b STRING) USING parquet" - val constraint = UniqueConstraint( - columns 
= Seq("a"), - name = "uk1" - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with multiple unique constraints") { - val sql = "CREATE TABLE t (a INT UNIQUE, b STRING, UNIQUE (b)) USING parquet" - val constraint1 = UniqueConstraint(columns = Seq("a")) - val constraint2 = UniqueConstraint(columns = Seq("b")) - val constraints = Seq(constraint1, constraint2) - verifyConstraints(sql, constraints) - } - - test("Create table with foreign key - table level") { - val sql = "CREATE TABLE t (a INT, b STRING," + - " FOREIGN KEY (a) REFERENCES parent(id)) USING parquet" - val constraint = ForeignKeyConstraint( - childColumns = Seq("a"), - parentTableId = Seq("parent"), - parentColumns = Seq("id") - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with named foreign key - table level") { - val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT fk1 FOREIGN KEY (a)" + - " REFERENCES parent(id)) USING parquet" - val constraint = ForeignKeyConstraint( - childColumns = Seq("a"), - parentTableId = Seq("parent"), - parentColumns = Seq("id"), - name = "fk1" - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with foreign key - column level") { - val sql = "CREATE TABLE t (a INT REFERENCES parent(id), b STRING) USING parquet" - val constraint = ForeignKeyConstraint( - childColumns = Seq("a"), - parentTableId = Seq("parent"), - parentColumns = Seq("id") - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - - test("Create table with named foreign key - column level") { - val sql = "CREATE TABLE t (a INT CONSTRAINT fk1 REFERENCES parent(id), b STRING) USING parquet" - val constraint = ForeignKeyConstraint( - childColumns = Seq("a"), - parentTableId = Seq("parent"), - parentColumns = Seq("id"), - name = "fk1" - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ForeignKeyConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ForeignKeyConstraintParseSuite.scala new file mode 100644 index 000000000000..2c119610e008 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ForeignKeyConstraintParseSuite.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.UnresolvedTable +import org.apache.spark.sql.catalyst.expressions.ForeignKeyConstraint +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.AddConstraint + +class ForeignKeyConstraintParseSuite extends ConstraintParseSuiteBase { + test("Create table with foreign key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING," + + " FOREIGN KEY (a) REFERENCES parent(id)) USING parquet" + val constraint = ForeignKeyConstraint( + childColumns = Seq("a"), + parentTableId = Seq("parent"), + parentColumns = Seq("id") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named foreign key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT fk1 FOREIGN KEY (a)" + + " REFERENCES parent(id)) USING parquet" + val constraint = ForeignKeyConstraint( + childColumns = Seq("a"), + parentTableId = Seq("parent"), + parentColumns = Seq("id"), + name = "fk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with foreign key - column level") { + val sql = "CREATE TABLE t (a INT REFERENCES parent(id), b STRING) USING parquet" + val constraint = ForeignKeyConstraint( + childColumns = Seq("a"), + parentTableId = Seq("parent"), + parentColumns = Seq("id") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named foreign key - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT fk1 REFERENCES parent(id), b STRING) USING parquet" + val constraint = ForeignKeyConstraint( + childColumns = Seq("a"), + parentTableId = Seq("parent"), + parentColumns = Seq("id"), + name = "fk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Add foreign key constraint") { + Seq(("", null), ("CONSTRAINT fk1", "fk1")).foreach { case (constraintName, expectedName) => + val sql = + s""" + |ALTER TABLE orders ADD $constraintName FOREIGN KEY (customer_id) + |REFERENCES customers (id) + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("orders"), + "ALTER TABLE ... 
ADD CONSTRAINT"), + ForeignKeyConstraint( + name = expectedName, + childColumns = Seq("customer_id"), + parentTableId = Seq("customers"), + parentColumns = Seq("id") + )) + comparePlans(parsed, expected) + } + } + + test("Add invalid foreign key constraint name") { + val sql = + """ + |ALTER TABLE orders ADD CONSTRAINT fk-1 FOREIGN KEY (customer_id) + |REFERENCES customers (id) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "'-'", "hint" -> "")) + } + + test("Add foreign key constraint with empty columns") { + val sql = + """ + |ALTER TABLE orders ADD CONSTRAINT fk1 FOREIGN KEY () + |REFERENCES customers (id) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "')'", "hint" -> "")) + } + + test("Add foreign key constraint with valid characteristic") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = + s""" + |ALTER TABLE orders ADD CONSTRAINT fk1 FOREIGN KEY (customer_id) + |REFERENCES customers (id) $enforcedStr $relyStr + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("orders"), + "ALTER TABLE ... ADD CONSTRAINT"), + ForeignKeyConstraint( + name = "fk1", + childColumns = Seq("customer_id"), + parentTableId = Seq("customers"), + parentColumns = Seq("id"), + characteristic = characteristic + )) + comparePlans(parsed, expected) + } + } + + test("Add foreign key constraint with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val sql = + s""" + |ALTER TABLE orders ADD CONSTRAINT fk1 FOREIGN KEY (customer_id) + |REFERENCES customers (id) $characteristic1 $characteristic2 + |""".stripMargin + + val e = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT fk1 FOREIGN KEY (customer_id)\nREFERENCES customers (id) " + + s"$characteristic1 $characteristic2", + start = 24, + stop = 91 + characteristic1.length + characteristic2.length + ) + checkError( + exception = e, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for foreign key -- create table with unnamed constraint") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"CREATE TABLE t (id INT REFERENCES parent(id) $characteristic) USING parquet" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"REFERENCES parent(id) $characteristic", + start = 23, + stop = 44 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "FOREIGN KEY"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for foreign key -- create table with named constraint") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"CREATE TABLE t (id INT, CONSTRAINT fk1 FOREIGN KEY (id)" + + s" REFERENCES parent(id) $characteristic) USING parquet" + val error = 
intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT fk1 FOREIGN KEY (id) REFERENCES parent(id) $characteristic", + start = 24, + stop = 77 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "FOREIGN KEY"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for foreign key -- alter table") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT fk1 FOREIGN KEY (id)" + + s" REFERENCES parent(id) $characteristic" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT fk1 FOREIGN KEY (id) REFERENCES parent(id) $characteristic", + start = 22, + stop = 75 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "FOREIGN KEY"), + queryContext = Array(expectedContext)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PrimaryKeyConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PrimaryKeyConstraintParseSuite.scala new file mode 100644 index 000000000000..2d704cbe2659 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PrimaryKeyConstraintParseSuite.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.UnresolvedTable +import org.apache.spark.sql.catalyst.expressions.PrimaryKeyConstraint +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.AddConstraint + +class PrimaryKeyConstraintParseSuite extends ConstraintParseSuiteBase { + + test("Create table with primary key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, PRIMARY KEY (a)) USING parquet" + val constraint = PrimaryKeyConstraint(columns = Seq("a")) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named primary key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT pk1 PRIMARY KEY (a)) USING parquet" + val constraint = PrimaryKeyConstraint( + columns = Seq("a"), + name = "pk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with composite primary key - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, PRIMARY KEY (a, b)) USING parquet" + val constraint = PrimaryKeyConstraint( + columns = Seq("a", "b") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with primary key - column level") { + val sql = "CREATE TABLE t (a INT PRIMARY KEY, b STRING) USING parquet" + val constraint = PrimaryKeyConstraint( + columns = Seq("a") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named primary key - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT pk1 PRIMARY KEY, b STRING) USING parquet" + val constraint = PrimaryKeyConstraint( + columns = Seq("a"), + name = "pk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with multiple primary keys should fail") { + val expectedContext = ExpectedContext( + fragment = "a INT PRIMARY KEY, b STRING, PRIMARY KEY (b)", + start = 16, + stop = 59 + ) + checkError( + exception = intercept[ParseException] { + parsePlan("CREATE TABLE t (a INT PRIMARY KEY, b STRING, PRIMARY KEY (b)) USING parquet") + }, + condition = "MULTIPLE_PRIMARY_KEYS", + parameters = Map.empty[String, String], + queryContext = Array(expectedContext)) + } + + test("Add primary key constraint") { + Seq(("", null), ("CONSTRAINT pk1", "pk1")).foreach { case (constraintName, expectedName) => + val sql = + s""" + |ALTER TABLE a.b.c ADD $constraintName PRIMARY KEY (id, name) + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... 
ADD CONSTRAINT"), + PrimaryKeyConstraint( + name = expectedName, + columns = Seq("id", "name") + )) + comparePlans(parsed, expected) + } + } + + test("Add invalid primary key constraint name") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT pk-1 PRIMARY KEY (id) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "'-'", "hint" -> "")) + } + + test("Add primary key constraint with empty columns") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY () + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "')'", "hint" -> "")) + } + + test("Add primary key constraint with valid characteristic") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = + s""" + |ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY (id) $enforcedStr $relyStr + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... ADD CONSTRAINT"), + PrimaryKeyConstraint( + name = "pk1", + columns = Seq("id"), + characteristic = characteristic + )) + comparePlans(parsed, expected) + } + } + + test("Add primary key constraint with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY (id) $characteristic1 $characteristic2" + + val e = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT pk1 PRIMARY KEY (id) $characteristic1 $characteristic2", + start = 22, + stop = 54 + characteristic1.length + characteristic2.length + ) + checkError( + exception = e, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for primary key -- create table with unnamed constraint") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"CREATE TABLE t (id INT PRIMARY KEY $characteristic) USING parquet" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"PRIMARY KEY $characteristic", + start = 23, + stop = 34 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "PRIMARY KEY"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for primary key -- create table with named constraint") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"CREATE TABLE t (id INT, CONSTRAINT pk1 PRIMARY KEY (id) $characteristic) USING parquet" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT pk1 PRIMARY KEY (id) $characteristic", + start = 24, + stop = 55 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "PRIMARY KEY"), + queryContext = 
Array(expectedContext)) + } + } + + test("ENFORCED is not supported for primary key -- alter table") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT pk1 PRIMARY KEY (id) $characteristic" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT pk1 PRIMARY KEY (id) $characteristic", + start = 22, + stop = 53 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "PRIMARY KEY"), + queryContext = Array(expectedContext)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/UniqueConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/UniqueConstraintParseSuite.scala new file mode 100644 index 000000000000..6c6cf851c95b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/UniqueConstraintParseSuite.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.UnresolvedTable +import org.apache.spark.sql.catalyst.expressions.UniqueConstraint +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.AddConstraint + +class UniqueConstraintParseSuite extends ConstraintParseSuiteBase { + test("Create table with unique constraint - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, UNIQUE (a)) USING parquet" + val constraint = UniqueConstraint(columns = Seq("a")) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named unique constraint - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT uk1 UNIQUE (a)) USING parquet" + val constraint = UniqueConstraint( + columns = Seq("a"), + name = "uk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with composite unique constraint - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, UNIQUE (a, b)) USING parquet" + val constraint = UniqueConstraint( + columns = Seq("a", "b") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with unique constraint - column level") { + val sql = "CREATE TABLE t (a INT UNIQUE, b STRING) USING parquet" + val constraint = UniqueConstraint( + columns = Seq("a") + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with named unique constraint - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT uk1 UNIQUE, b STRING) USING parquet" + val constraint = UniqueConstraint( + columns = Seq("a"), + name = "uk1" + ) + val constraints = Seq(constraint) + verifyConstraints(sql, constraints) + } + + test("Create table with multiple unique constraints") { + val sql = "CREATE TABLE t (a INT UNIQUE, b STRING, UNIQUE (b)) USING parquet" + val constraint1 = UniqueConstraint(columns = Seq("a")) + val constraint2 = UniqueConstraint(columns = Seq("b")) + val constraints = Seq(constraint1, constraint2) + verifyConstraints(sql, constraints) + } + + test("Add unique constraint") { + Seq(("", null), ("CONSTRAINT uk1", "uk1")).foreach { case (constraintName, expectedName) => + val sql = + s""" + |ALTER TABLE a.b.c ADD $constraintName UNIQUE (email, username) + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... 
ADD CONSTRAINT"), + UniqueConstraint( + name = expectedName, + columns = Seq("email", "username") + )) + comparePlans(parsed, expected) + } + } + + test("Add invalid unique constraint name") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT uk-1 UNIQUE (email) + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "'-'", "hint" -> "")) + } + + test("Add unique constraint with empty columns") { + val sql = + """ + |ALTER TABLE a.b.c ADD CONSTRAINT uk1 UNIQUE () + |""".stripMargin + val e = intercept[ParseException] { + parsePlan(sql) + } + checkError(e, "PARSE_SYNTAX_ERROR", "42601", Map("error" -> "')'", "hint" -> "")) + } + + test("Add unique constraint with valid characteristic") { + validConstraintCharacteristics.foreach { case (enforcedStr, relyStr, characteristic) => + val sql = + s""" + |ALTER TABLE a.b.c ADD CONSTRAINT uk1 UNIQUE (email) $enforcedStr $relyStr + |""".stripMargin + val parsed = parsePlan(sql) + val expected = AddConstraint( + UnresolvedTable( + Seq("a", "b", "c"), + "ALTER TABLE ... ADD CONSTRAINT"), + UniqueConstraint( + name = "uk1", + columns = Seq("email"), + characteristic = characteristic + )) + comparePlans(parsed, expected) + } + } + + test("Add unique constraint with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT uk1 UNIQUE (email) $characteristic1 $characteristic2" + + val e = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT uk1 UNIQUE (email) $characteristic1 $characteristic2", + start = 22, + stop = 52 + characteristic1.length + characteristic2.length + ) + checkError( + exception = e, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + + + test("ENFORCED is not supported for unique -- create table with unnamed constraint") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"CREATE TABLE t (id INT UNIQUE $characteristic) USING parquet" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"UNIQUE $characteristic", + start = 23, + stop = 29 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "UNIQUE"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for unique -- create table with named constraint") { + enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"CREATE TABLE t (id INT CONSTRAINT uk1 UNIQUE $characteristic) USING parquet" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT uk1 UNIQUE $characteristic", + start = 23, + stop = 44 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "UNIQUE"), + queryContext = Array(expectedContext)) + } + } + + test("ENFORCED is not supported for unique -- alter table") { + 
enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => + val characteristic = if (c2.isEmpty) { + c1 + } else { + s"$c1 $c2" + } + val sql = + s"ALTER TABLE a.b.c ADD CONSTRAINT uni UNIQUE (id) $characteristic" + val error = intercept[ParseException] { + parsePlan(sql) + } + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT uni UNIQUE (id) $characteristic", + start = 22, + stop = 48 + characteristic.length + ) + checkError( + exception = error, + condition = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC", + parameters = Map("characteristic" -> "ENFORCED", "constraintType" -> "UNIQUE"), + queryContext = Array(expectedContext)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/PrimaryKeyConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/PrimaryKeyConstraintSuite.scala index 731754062467..ae404aff274a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/PrimaryKeyConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/PrimaryKeyConstraintSuite.scala @@ -19,10 +19,18 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.execution.command.DDLCommandTestUtils - class PrimaryKeyConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { override protected def command: String = "ALTER TABLE .. ADD CONSTRAINT" + private val validConstraintCharacteristics = Seq( + ("", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("RELY", "NOT ENFORCED UNVALIDATED RELY") + ) + test("Add primary key constraint") { validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => withNamespaceAndTable("ns", "tbl", catalog) { t => @@ -89,15 +97,4 @@ class PrimaryKeyConstraintSuite extends QueryTest with CommandSuiteBase with DDL "CONSTRAINT pk1 PRIMARY KEY (id1, id2) NOT ENFORCED UNVALIDATED NORELY") } } - - val validConstraintCharacteristics = Seq( - ("", "NOT ENFORCED UNVALIDATED NORELY"), - ("NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), - ("NOT ENFORCED NORELY", "NOT ENFORCED UNVALIDATED NORELY"), - ("NORELY NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), - ("NORELY", "NOT ENFORCED UNVALIDATED NORELY"), - ("ENFORCED RELY", "ENFORCED UNVALIDATED RELY"), - ("RELY ENFORCED", "ENFORCED UNVALIDATED RELY"), - ("RELY", "NOT ENFORCED UNVALIDATED RELY") - ) } From d1abaf35d703363417c8eba35062540631f8b67b Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 31 Mar 2025 15:23:16 -0700 Subject: [PATCH 55/65] add test for unique and fk --- .../v2/ForeignKeyConstraintSuite.scala | 112 ++++++++++++++++++ .../command/v2/UniqueConstraintSuite.scala | 100 ++++++++++++++++ 2 files changed, 212 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ForeignKeyConstraintSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/UniqueConstraintSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ForeignKeyConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ForeignKeyConstraintSuite.scala new file mode 100644 index 000000000000..14abe9d1ab9b --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ForeignKeyConstraintSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.execution.command.DDLCommandTestUtils + +class ForeignKeyConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { + override protected def command: String = "ALTER TABLE .. ADD CONSTRAINT" + + private val validConstraintCharacteristics = Seq( + ("", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("RELY", "NOT ENFORCED UNVALIDATED RELY") + ) + + test("Add foreign key constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, fk bigint, data string) $defaultUsing") + sql(s"CREATE TABLE ${t}_ref (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT fk1 FOREIGN KEY (fk) " + + s"REFERENCES ${t}_ref(id) $characteristic") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "fk1") + assert(constraint.toDDL == s"CONSTRAINT fk1 FOREIGN KEY (fk) " + + s"REFERENCES test_catalog.ns.tbl_ref (id) $expectedDDL") + } + } + } + + test("Create table with foreign key constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE ${t}_ref (id bigint, data string) $defaultUsing") + val constraintStr = s"CONSTRAINT fk1 FOREIGN KEY (fk) " + + s"REFERENCES ${t}_ref(id) $characteristic" + sql(s"CREATE TABLE $t (id bigint, fk bigint, data string, $constraintStr) $defaultUsing") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "fk1") + assert(constraint.toDDL == s"CONSTRAINT fk1 FOREIGN KEY (fk) " + + s"REFERENCES test_catalog.ns.tbl_ref (id) $expectedDDL") + } + } + } + + test("Add duplicated foreign key constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, fk bigint, data string) $defaultUsing") + sql(s"CREATE TABLE ${t}_ref (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE 
$t ADD CONSTRAINT fk1 FOREIGN KEY (fk) REFERENCES ${t}_ref(id)") + // Constraint names are case-insensitive + Seq("fk1", "FK1").foreach { name => + val error = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ADD CONSTRAINT $name FOREIGN KEY (fk) REFERENCES ${t}_ref(id)") + } + checkError( + exception = error, + condition = "CONSTRAINT_ALREADY_EXISTS", + sqlState = "42710", + parameters = Map("constraintName" -> "fk1", + "oldConstraint" -> + ("CONSTRAINT fk1 FOREIGN KEY (fk) " + + "REFERENCES test_catalog.ns.tbl_ref (id) NOT ENFORCED UNVALIDATED NORELY")) + ) + } + } + } + + test("Add foreign key constraint with multiple columns") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id1 bigint, id2 bigint, fk1 bigint, fk2 bigint, data string) " + + s"$defaultUsing") + sql(s"CREATE TABLE ${t}_ref (id1 bigint, id2 bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT fk1 FOREIGN KEY (fk1, fk2) REFERENCES ${t}_ref(id1, id2)") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "fk1") + assert(constraint.toDDL == + s"CONSTRAINT fk1 FOREIGN KEY (fk1, fk2) " + + s"REFERENCES test_catalog.ns.tbl_ref (id1, id2) NOT ENFORCED UNVALIDATED NORELY") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/UniqueConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/UniqueConstraintSuite.scala new file mode 100644 index 000000000000..4eee2c248cfd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/UniqueConstraintSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.execution.command.DDLCommandTestUtils + +class UniqueConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { + override protected def command: String = "ALTER TABLE .. 
ADD CONSTRAINT" + + private val validConstraintCharacteristics = Seq( + ("", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NOT ENFORCED NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY NOT ENFORCED", "NOT ENFORCED UNVALIDATED NORELY"), + ("NORELY", "NOT ENFORCED UNVALIDATED NORELY"), + ("RELY", "NOT ENFORCED UNVALIDATED RELY") + ) + + test("Add unique constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT uk1 UNIQUE (id) $characteristic") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "uk1") + assert(constraint.toDDL == s"CONSTRAINT uk1 UNIQUE (id) $expectedDDL") + } + } + } + + test("Create table with unique constraint") { + validConstraintCharacteristics.foreach { case (characteristic, expectedDDL) => + withNamespaceAndTable("ns", "tbl", catalog) { t => + val constraintStr = s"CONSTRAINT uk1 UNIQUE (id) $characteristic" + sql(s"CREATE TABLE $t (id bigint, data string, $constraintStr) $defaultUsing") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "uk1") + assert(constraint.toDDL == s"CONSTRAINT uk1 UNIQUE (id) $expectedDDL") + } + } + } + + test("Add duplicated unique constraint") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT uk1 UNIQUE (id)") + // Constraint names are case-insensitive + Seq("uk1", "UK1").foreach { name => + val error = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ADD CONSTRAINT $name UNIQUE (id)") + } + checkError( + exception = error, + condition = "CONSTRAINT_ALREADY_EXISTS", + sqlState = "42710", + parameters = Map("constraintName" -> "uk1", + "oldConstraint" -> "CONSTRAINT uk1 UNIQUE (id) NOT ENFORCED UNVALIDATED NORELY") + ) + } + } + } + + test("Add unique constraint with multiple columns") { + withNamespaceAndTable("ns", "tbl", catalog) { t => + sql(s"CREATE TABLE $t (id1 bigint, id2 bigint, data string) $defaultUsing") + assert(loadTable(catalog, "ns", "tbl").constraints.isEmpty) + + sql(s"ALTER TABLE $t ADD CONSTRAINT uk1 UNIQUE (id1, id2)") + val table = loadTable(catalog, "ns", "tbl") + assert(table.constraints.length == 1) + val constraint = table.constraints.head + assert(constraint.name() == "uk1") + assert(constraint.toDDL == + "CONSTRAINT uk1 UNIQUE (id1, id2) NOT ENFORCED UNVALIDATED NORELY") + } + } +} From 528b6ff59c44b327113834535571cb581e277569 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 31 Mar 2025 21:42:42 -0700 Subject: [PATCH 56/65] save for now --- .../catalyst/expressions/constraints.scala | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index d233ae6d4120..120970c03868 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.antlr.v4.runtime.ParserRuleContext -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder import org.apache.spark.sql.connector.catalog.constraints.Constraint @@ -27,7 +26,7 @@ import org.apache.spark.sql.types.{DataType, StringType} trait TableConstraint { // Convert to a data source v2 constraint - def asConstraint(isCreateTable: Boolean): Constraint + def asConstraint(tableName: String, isCreateTable: Boolean): Constraint def withNameAndCharacteristic( name: String, @@ -38,9 +37,10 @@ trait TableConstraint { def characteristic: ConstraintCharacteristic - def defaultName: String + // Generate a constraint name based on the table name if the name is not specified + protected def generatedName(tableName: String): String - def defaultConstraintCharacteristic: ConstraintCharacteristic + protected def defaultConstraintCharacteristic: ConstraintCharacteristic protected def getCharacteristicValues: (Boolean, Boolean) = { val rely = characteristic.rely.getOrElse(defaultConstraintCharacteristic.rely.get) @@ -64,10 +64,10 @@ case class CheckConstraint( with Unevaluable with TableConstraint { - def asConstraint(isCreateTable: Boolean): Constraint = { + def asConstraint(tableName: String, isCreateTable: Boolean): Constraint = { val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull val (rely, enforced) = getCharacteristicValues - val constraintName = if (name == null) defaultName else name + val constraintName = if (name == null) generatedName(tableName) else name // The validation status is set to UNVALIDATED for create table and // VALID for alter table. 
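     // (Rationale: a newly created table has no existing rows to check, whereas
     // ALTER TABLE ... ADD CONSTRAINT is applied to data that is already present.)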
val validateStatus = if (isCreateTable) { @@ -95,10 +95,11 @@ case class CheckConstraint( copy(name = name, characteristic = c) } - override def defaultName: String = - throw new AnalysisException( - errorClass = "INVALID_CHECK_CONSTRAINT.MISSING_NAME", - messageParameters = Map.empty) + override protected def generatedName(tableName: String): String = { + val base = condition.filter(_.isLetterOrDigit).take(20) + val rand = scala.util.Random.alphanumeric.take(6).mkString + s"${tableName}_chk_${base}_$rand" + } override def defaultConstraintCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic(enforced = Some(true), rely = Some(false)) @@ -114,9 +115,9 @@ case class PrimaryKeyConstraint( override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) extends TableConstraint { - override def asConstraint(isCreateTable: Boolean): Constraint = { + override def asConstraint(tableName: String, isCreateTable: Boolean): Constraint = { val (rely, enforced) = getCharacteristicValues - val constraintName = if (name == null) defaultName else name + val constraintName = if (name == null) generatedName(tableName) else name Constraint .primaryKey(constraintName, columns.map(FieldReference.column).toArray) .rely(rely) @@ -141,7 +142,7 @@ case class PrimaryKeyConstraint( copy(name = name, characteristic = c) } - override def defaultName: String = "pk" + override protected def generatedName(tableName: String): String = s"${tableName}_pk" override def defaultConstraintCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic(enforced = Some(false), rely = Some(false)) @@ -153,9 +154,9 @@ case class UniqueConstraint( override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) extends TableConstraint { - override def asConstraint(isCreateTable: Boolean): Constraint = { + override def asConstraint(tableName: String, isCreateTable: Boolean): Constraint = { val (rely, enforced) = getCharacteristicValues - val constraintName = if (name == null) defaultName else name + val constraintName = if (name == null) generatedName(tableName) else name Constraint .unique(constraintName, columns.map(FieldReference.column).toArray) .rely(rely) @@ -180,10 +181,11 @@ case class UniqueConstraint( copy(name = name, characteristic = c) } - override def defaultName: String = - throw new AnalysisException( - errorClass = "INVALID_UNIQUE_CONSTRAINT.MISSING_NAME", - messageParameters = Map.empty) + override protected def generatedName(tableName: String): String = { + val base = columns.map(_.filter(_.isLetterOrDigit)).sorted.mkString("_").take(20) + val rand = scala.util.Random.alphanumeric.take(6).mkString + s"${tableName}_uk_${base}_$rand" + } override def defaultConstraintCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic(enforced = Some(false), rely = Some(false)) @@ -198,9 +200,9 @@ case class ForeignKeyConstraint( extends TableConstraint { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - override def asConstraint(isCreateTable: Boolean): Constraint = { + override def asConstraint(tableName: String, isCreateTable: Boolean): Constraint = { val (rely, enforced) = getCharacteristicValues - val constraintName = if (name == null) defaultName else name + val constraintName = if (name == null) generatedName(tableName) else name Constraint .foreignKey(constraintName, childColumns.map(FieldReference.column).toArray, @@ -228,7 +230,8 @@ case class ForeignKeyConstraint( copy(name = name, characteristic = c) } - override def 
defaultName: String = "fk" + override protected def generatedName(tableName: String): String = + s"${tableName}_fk_${parentTableId.last}" override def defaultConstraintCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic(enforced = Some(false), rely = Some(false)) From d14f5e586c706a80abc151d7e32dc743a0d80ea6 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 1 Apr 2025 11:49:29 -0700 Subject: [PATCH 57/65] support generated name --- .../sql/catalyst/parser/SqlBaseParser.g4 | 3 +- .../catalyst/expressions/constraints.scala | 43 +++++++------ .../sql/catalyst/parser/AstBuilder.scala | 61 ++++++++++--------- 3 files changed, 60 insertions(+), 47 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index f402e9625063..f930552444a9 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1535,7 +1535,8 @@ columnConstraintDefinition ; columnConstraint - : uniqueSpec + : checkConstraint + | uniqueSpec | referenceSpec ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index 120970c03868..fd3cb1d08f7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types.{DataType, StringType} trait TableConstraint { // Convert to a data source v2 constraint - def asConstraint(tableName: String, isCreateTable: Boolean): Constraint + def asConstraint(isCreateTable: Boolean): Constraint def withNameAndCharacteristic( name: String, @@ -37,8 +37,19 @@ trait TableConstraint { def characteristic: ConstraintCharacteristic + def generateConstraintNameIfNeeded(tableName: String): TableConstraint = { + if (name == null || name.isEmpty) { + this.withNameAndCharacteristic( + name = generateConstraintName(tableName), + c = characteristic, + ctx = null) + } else { + this + } + } + // Generate a constraint name based on the table name if the name is not specified - protected def generatedName(tableName: String): String + protected def generateConstraintName(tableName: String): String protected def defaultConstraintCharacteristic: ConstraintCharacteristic @@ -64,10 +75,9 @@ case class CheckConstraint( with Unevaluable with TableConstraint { - def asConstraint(tableName: String, isCreateTable: Boolean): Constraint = { + def asConstraint(isCreateTable: Boolean): Constraint = { val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull val (rely, enforced) = getCharacteristicValues - val constraintName = if (name == null) generatedName(tableName) else name // The validation status is set to UNVALIDATED for create table and // VALID for alter table. 
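     // Note: `name` is assumed to be non-null at this point; when the user omits a
     // constraint name, the parser fills one in via generateConstraintNameIfNeeded
     // before this method is reached.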
val validateStatus = if (isCreateTable) { @@ -76,7 +86,7 @@ case class CheckConstraint( Constraint.ValidationStatus.VALID } Constraint - .check(constraintName) + .check(name) .sql(condition) .predicate(predicate) .rely(rely) @@ -95,7 +105,7 @@ case class CheckConstraint( copy(name = name, characteristic = c) } - override protected def generatedName(tableName: String): String = { + override protected def generateConstraintName(tableName: String): String = { val base = condition.filter(_.isLetterOrDigit).take(20) val rand = scala.util.Random.alphanumeric.take(6).mkString s"${tableName}_chk_${base}_$rand" @@ -115,11 +125,10 @@ case class PrimaryKeyConstraint( override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) extends TableConstraint { - override def asConstraint(tableName: String, isCreateTable: Boolean): Constraint = { + override def asConstraint(isCreateTable: Boolean): Constraint = { val (rely, enforced) = getCharacteristicValues - val constraintName = if (name == null) generatedName(tableName) else name Constraint - .primaryKey(constraintName, columns.map(FieldReference.column).toArray) + .primaryKey(name, columns.map(FieldReference.column).toArray) .rely(rely) .enforced(enforced) .validationStatus(Constraint.ValidationStatus.UNVALIDATED) @@ -142,7 +151,7 @@ case class PrimaryKeyConstraint( copy(name = name, characteristic = c) } - override protected def generatedName(tableName: String): String = s"${tableName}_pk" + override protected def generateConstraintName(tableName: String): String = s"${tableName}_pk" override def defaultConstraintCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic(enforced = Some(false), rely = Some(false)) @@ -154,11 +163,10 @@ case class UniqueConstraint( override val characteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty) extends TableConstraint { - override def asConstraint(tableName: String, isCreateTable: Boolean): Constraint = { + override def asConstraint(isCreateTable: Boolean): Constraint = { val (rely, enforced) = getCharacteristicValues - val constraintName = if (name == null) generatedName(tableName) else name Constraint - .unique(constraintName, columns.map(FieldReference.column).toArray) + .unique(name, columns.map(FieldReference.column).toArray) .rely(rely) .enforced(enforced) .validationStatus(Constraint.ValidationStatus.UNVALIDATED) @@ -181,7 +189,7 @@ case class UniqueConstraint( copy(name = name, characteristic = c) } - override protected def generatedName(tableName: String): String = { + override protected def generateConstraintName(tableName: String): String = { val base = columns.map(_.filter(_.isLetterOrDigit)).sorted.mkString("_").take(20) val rand = scala.util.Random.alphanumeric.take(6).mkString s"${tableName}_uk_${base}_$rand" @@ -200,11 +208,10 @@ case class ForeignKeyConstraint( extends TableConstraint { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - override def asConstraint(tableName: String, isCreateTable: Boolean): Constraint = { + override def asConstraint(isCreateTable: Boolean): Constraint = { val (rely, enforced) = getCharacteristicValues - val constraintName = if (name == null) generatedName(tableName) else name Constraint - .foreignKey(constraintName, + .foreignKey(name, childColumns.map(FieldReference.column).toArray, parentTableId.asIdentifier, parentColumns.map(FieldReference.column).toArray) @@ -230,7 +237,7 @@ case class ForeignKeyConstraint( copy(name = name, characteristic = c) } - override protected def 
generatedName(tableName: String): String = + override protected def generateConstraintName(tableName: String): String = s"${tableName}_fk_${parentTableId.last}" override def defaultConstraintCharacteristic: ConstraintCharacteristic = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index fa10d96b75d5..3ab0d0e6152f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -4842,33 +4842,37 @@ class AstBuilder extends DataTypeAstBuilder bucketSpec.map(_.asTransform) ++ clusterBySpec.map(_.asTransform) - val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, - collation, serdeInfo, external, constraints) - - Option(ctx.query).map(plan) match { - case Some(_) if columns.nonEmpty => - operationNotAllowed( - "Schema may not be specified in a Create Table As Select (CTAS) statement", - ctx) + val planOpt = Option(ctx.query).map(plan) + withIdentClause(identifierContext, planOpt.toSeq, (identifiers, otherPlans) => { + val namedConstraints = + constraints.map(c => c.generateConstraintNameIfNeeded(identifiers.last)) + val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, + collation, serdeInfo, external, namedConstraints) + val identifier = withOrigin(identifierContext) { + UnresolvedIdentifier(identifiers) + } + otherPlans.headOption match { + case Some(_) if columns.nonEmpty => + operationNotAllowed( + "Schema may not be specified in a Create Table As Select (CTAS) statement", + ctx) - case Some(_) if partCols.nonEmpty => - // non-reference partition columns are not allowed because schema can't be specified - operationNotAllowed( - "Partition column types may not be specified in Create Table As Select (CTAS)", - ctx) + case Some(_) if partCols.nonEmpty => + // non-reference partition columns are not allowed because schema can't be specified + operationNotAllowed( + "Partition column types may not be specified in Create Table As Select (CTAS)", + ctx) - case Some(query) => - CreateTableAsSelect(withIdentClause(identifierContext, UnresolvedIdentifier(_)), - partitioning, query, tableSpec, Map.empty, ifNotExists) + case Some(query) => + CreateTableAsSelect(identifier, partitioning, query, tableSpec, Map.empty, ifNotExists) - case _ => - // Note: table schema includes both the table columns list and the partition columns - // with data type. - val allColumns = columns ++ partCols - CreateTable( - withIdentClause(identifierContext, UnresolvedIdentifier(_)), - allColumns, partitioning, tableSpec, ignoreIfExists = ifNotExists) - } + case _ => + // Note: table schema includes both the table columns list and the partition columns + // with data type. + val allColumns = columns ++ partCols + CreateTable(identifier, allColumns, partitioning, tableSpec, ignoreIfExists = ifNotExists) + } + }) } /** @@ -5424,13 +5428,14 @@ class AstBuilder extends DataTypeAstBuilder */ override def visitAddTableConstraint(ctx: AddTableConstraintContext): LogicalPlan = withOrigin(ctx) { - val table = createUnresolvedTable( - ctx.identifierReference, "ALTER TABLE ... 
ADD CONSTRAINT") val tableConstraint = visitTableConstraintDefinition(ctx.tableConstraintDefinition()) - AddConstraint(table, tableConstraint) + withIdentClause(ctx.identifierReference, identifiers => { + val table = UnresolvedTable(identifiers, "ALTER TABLE ... ADD CONSTRAINT") + val namedConstraint = tableConstraint.generateConstraintNameIfNeeded(identifiers.last) + AddConstraint(table, namedConstraint) + }) } - /** * Parse a [[DropConstraint]] command. * From 0b84a72c7608dc6158cfda8834a71a90e331088e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 1 Apr 2025 12:51:46 -0700 Subject: [PATCH 58/65] unnamed check constraint --- .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../command/CheckConstraintParseSuite.scala | 40 ++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 3ab0d0e6152f..cafe8f20e913 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3969,7 +3969,9 @@ class AstBuilder extends DataTypeAstBuilder columnName: String, ctx: ColumnConstraintContext): TableConstraint = withOrigin(ctx) { val columns = Seq(columnName) - if (ctx.uniqueSpec() != null) { + if (ctx.checkConstraint() != null) { + visitCheckConstraint(ctx.checkConstraint()) + } else if (ctx.uniqueSpec() != null) { visitUniqueSpec(ctx.uniqueSpec(), columns) } else { assert(ctx.referenceSpec() != null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala index 58028ccb115e..28c00eec85c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedTa import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.{AddConstraint, ColumnDefinition} +import org.apache.spark.sql.catalyst.plans.logical.{AddConstraint, ColumnDefinition, CreateTable, UnresolvedTableSpec} import org.apache.spark.sql.types.StringType class CheckConstraintParseSuite extends ConstraintParseSuiteBase { @@ -174,4 +174,42 @@ class CheckConstraintParseSuite extends ConstraintParseSuiteBase { queryContext = Array(expectedContext)) } } + + test("Create table with unnamed check constraint") { + Seq( + "CREATE TABLE t (a INT, b STRING, CHECK (a > 0))", + "CREATE TABLE t (a INT CHECK (a > 0), b STRING)" + ).foreach { sql => + val plan = parsePlan(sql) + plan match { + case c: CreateTable => + val tableSpec = c.tableSpec.asInstanceOf[UnresolvedTableSpec] + assert(tableSpec.constraints.size == 1) + assert(tableSpec.constraints.head.isInstanceOf[CheckConstraint]) + assert(tableSpec.constraints.head.name.matches("t_chk_a0_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected CreateTable, but got: $other") + } + } + } + + test("Add unnamed check constraint") { + val sql = + """ + |ALTER TABLE a.b.c ADD CHECK (d > 0) + |""".stripMargin + val plan = parsePlan(sql) + plan 
match { + case a: AddConstraint => + val table = a.table.asInstanceOf[UnresolvedTable] + assert(table.multipartIdentifier == Seq("a", "b", "c")) + assert(a.tableConstraint.isInstanceOf[CheckConstraint]) + assert(a.tableConstraint.name.matches("c_chk_d0_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected AddConstraint, but got: $other") + } + } + } From d3f7df453556621abca1cf46307829a61251d46d Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 1 Apr 2025 13:16:23 -0700 Subject: [PATCH 59/65] simplify test --- .../command/CheckConstraintParseSuite.scala | 45 +++++++------------ 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala index 28c00eec85c3..04370a8d142e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala @@ -28,29 +28,25 @@ class CheckConstraintParseSuite extends ConstraintParseSuiteBase { override val validConstraintCharacteristics = super.validConstraintCharacteristics ++ super.enforcedConstraintCharacteristics + val constraint1 = CheckConstraint( + child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), + condition = "a > 0", + name = "c1") + val constraint2 = CheckConstraint( + child = EqualTo(UnresolvedAttribute("b"), Literal("foo")), + condition = "b = 'foo'", + name = "c2") + test("Create table with one check constraint - table level") { val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0)) USING parquet" - val constraint = CheckConstraint( - child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), - condition = "a > 0", - name = "c1") - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) + verifyConstraints(sql, Seq(constraint1)) } test("Create table with two check constraints - table level") { val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0), " + "CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" - val constraint1 = CheckConstraint( - child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), - condition = "a > 0", - name = "c1") - val constraint2 = CheckConstraint( - child = EqualTo(UnresolvedAttribute("b"), Literal("foo")), - condition = "b = 'foo'", - name = "c2") - val constraints = Seq(constraint1, constraint2) - verifyConstraints(sql, constraints) + + verifyConstraints(sql, Seq(constraint1, constraint2)) } test("Create table with valid characteristic - table level") { @@ -58,13 +54,8 @@ class CheckConstraintParseSuite extends ConstraintParseSuiteBase { case (enforcedStr, relyStr, characteristic) => val sql = s"CREATE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0) " + s"$enforcedStr $relyStr) USING parquet" - val constraint = CheckConstraint( - child = GreaterThan(UnresolvedAttribute("a"), Literal(0)), - condition = "a > 0", - name = "c1", - characteristic = characteristic) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) + val constraint = constraint1.withNameAndCharacteristic("c1", characteristic, null) + verifyConstraints(sql, Seq(constraint)) } } @@ -96,18 +87,14 @@ class CheckConstraintParseSuite extends ConstraintParseSuiteBase { test("Add check constraint") { val sql = """ - |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (d > 0) + |ALTER TABLE a.b.c ADD CONSTRAINT c1 CHECK (a > 0) |""".stripMargin val parsed = 
parsePlan(sql) val expected = AddConstraint( UnresolvedTable( Seq("a", "b", "c"), "ALTER TABLE ... ADD CONSTRAINT"), - CheckConstraint( - child = GreaterThan(UnresolvedAttribute("d"), Literal(0)), - condition = "d > 0", - name = "c1" - )) + constraint1) comparePlans(parsed, expected) } From 7e8273e01d58829dec6c2df7b82c6dfa5523d450 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 1 Apr 2025 13:38:35 -0700 Subject: [PATCH 60/65] support unnamed constraint in replace table;add test cases for replace table --- .../sql/catalyst/parser/AstBuilder.scala | 59 ++++++++-------- .../command/CheckConstraintParseSuite.scala | 68 ++++++++++++++++++- .../command/ConstraintParseSuiteBase.scala | 19 ++++-- 3 files changed, 112 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cafe8f20e913..6bdc5170b68c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -4844,8 +4844,7 @@ class AstBuilder extends DataTypeAstBuilder bucketSpec.map(_.asTransform) ++ clusterBySpec.map(_.asTransform) - val planOpt = Option(ctx.query).map(plan) - withIdentClause(identifierContext, planOpt.toSeq, (identifiers, otherPlans) => { + withIdentClause(identifierContext, identifiers => { val namedConstraints = constraints.map(c => c.generateConstraintNameIfNeeded(identifiers.last)) val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, @@ -4853,7 +4852,7 @@ class AstBuilder extends DataTypeAstBuilder val identifier = withOrigin(identifierContext) { UnresolvedIdentifier(identifiers) } - otherPlans.headOption match { + Option(ctx.query).map(plan) match { case Some(_) if columns.nonEmpty => operationNotAllowed( "Schema may not be specified in a Create Table As Select (CTAS) statement", @@ -4922,34 +4921,38 @@ class AstBuilder extends DataTypeAstBuilder bucketSpec.map(_.asTransform) ++ clusterBySpec.map(_.asTransform) - val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, - collation, serdeInfo, external = false, constraints) - - Option(ctx.query).map(plan) match { - case Some(_) if columns.nonEmpty => - operationNotAllowed( - "Schema may not be specified in a Replace Table As Select (RTAS) statement", - ctx) + val identifierContext = ctx.replaceTableHeader().identifierReference() + withIdentClause(identifierContext, identifiers => { + val namedConstraints = + constraints.map(c => c.generateConstraintNameIfNeeded(identifiers.last)) + val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, + collation, serdeInfo, external = false, namedConstraints) + val identifier = withOrigin(identifierContext) { + UnresolvedIdentifier(identifiers) + } + Option(ctx.query).map(plan) match { + case Some(_) if columns.nonEmpty => + operationNotAllowed( + "Schema may not be specified in a Replace Table As Select (RTAS) statement", + ctx) - case Some(_) if partCols.nonEmpty => - // non-reference partition columns are not allowed because schema can't be specified - operationNotAllowed( - "Partition column types may not be specified in Replace Table As Select (RTAS)", - ctx) + case Some(_) if partCols.nonEmpty => + // non-reference partition columns are not allowed because schema can't be specified + operationNotAllowed( + "Partition column types may not be specified 
in Replace Table As Select (RTAS)", + ctx) - case Some(query) => - ReplaceTableAsSelect( - withIdentClause(ctx.replaceTableHeader.identifierReference(), UnresolvedIdentifier(_)), - partitioning, query, tableSpec, writeOptions = Map.empty, orCreate = orCreate) + case Some(query) => + ReplaceTableAsSelect(identifier, partitioning, query, tableSpec, + writeOptions = Map.empty, orCreate = orCreate) - case _ => - // Note: table schema includes both the table columns list and the partition columns - // with data type. - val allColumns = columns ++ partCols - ReplaceTable( - withIdentClause(ctx.replaceTableHeader.identifierReference(), UnresolvedIdentifier(_)), - allColumns, partitioning, tableSpec, orCreate = orCreate) - } + case _ => + // Note: table schema includes both the table columns list and the partition columns + // with data type. + val allColumns = columns ++ partCols + ReplaceTable(identifier, allColumns, partitioning, tableSpec, orCreate = orCreate) + } + }) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala index 04370a8d142e..6b55bf93b158 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedTa import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.{AddConstraint, ColumnDefinition, CreateTable, UnresolvedTableSpec} +import org.apache.spark.sql.catalyst.plans.logical.{AddConstraint, ColumnDefinition, CreateTable, ReplaceTable, UnresolvedTableSpec} import org.apache.spark.sql.types.StringType class CheckConstraintParseSuite extends ConstraintParseSuiteBase { @@ -84,6 +84,53 @@ class CheckConstraintParseSuite extends ConstraintParseSuiteBase { comparePlans(parsePlan(sql), expected) } + test("Replace table with one check constraint - table level") { + val sql = "REPLACE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0)) USING parquet" + verifyConstraints(sql, Seq(constraint1), isCreateTable = false) + } + + test("Replace table with two check constraints - table level") { + val sql = "REPLACE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0), " + + "CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" + + verifyConstraints(sql, Seq(constraint1, constraint2), isCreateTable = false) + } + + test("Replace table with valid characteristic - table level") { + validConstraintCharacteristics.foreach { + case (enforcedStr, relyStr, characteristic) => + val sql = s"REPLACE TABLE t (a INT, b STRING, CONSTRAINT c1 CHECK (a > 0) " + + s"$enforcedStr $relyStr) USING parquet" + val constraint = constraint1.withNameAndCharacteristic("c1", characteristic, null) + verifyConstraints(sql, Seq(constraint), isCreateTable = false) + } + } + + test("Replace table with invalid characteristic") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2" + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2", + start = 34, + stop = 62 + 
characteristic1.length + characteristic2.length + ) + checkError( + exception = intercept[ParseException] { + parsePlan(s"REPLACE TABLE t (a INT, b STRING, $constraintStr ) USING parquet") + }, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + + test("Replace table with column 'constraint'") { + val sql = "REPLACE TABLE t (constraint STRING) USING parquet" + val columns = Seq(ColumnDefinition("constraint", StringType)) + val expected = createExpectedPlan(columns, Seq.empty, isCreateTable = false) + comparePlans(parsePlan(sql), expected) + } + test("Add check constraint") { val sql = """ @@ -181,6 +228,25 @@ class CheckConstraintParseSuite extends ConstraintParseSuiteBase { } } + test("Replace table with unnamed check constraint") { + Seq( + "REPLACE TABLE t (a INT, b STRING, CHECK (a > 0))", + "REPLACE TABLE t (a INT CHECK (a > 0), b STRING)" + ).foreach { sql => + val plan = parsePlan(sql) + plan match { + case c: ReplaceTable => + val tableSpec = c.tableSpec.asInstanceOf[UnresolvedTableSpec] + assert(tableSpec.constraints.size == 1) + assert(tableSpec.constraints.head.isInstanceOf[CheckConstraint]) + assert(tableSpec.constraints.head.name.matches("t_chk_a0_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected ReplaceTable, but got: $other") + } + } + } + test("Add unnamed check constraint") { val sql = """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala index e876207138c0..ea369489eb18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ConstraintParseSuiteBase.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.expressions.{ConstraintCharacteristic, TableConstraint} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan -import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, OptionList, UnresolvedTableSpec} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, LogicalPlan, OptionList, ReplaceTable, UnresolvedTableSpec} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType} @@ -56,21 +56,30 @@ abstract class ConstraintParseSuiteBase extends AnalysisTest with SharedSparkSes protected def createExpectedPlan( columns: Seq[ColumnDefinition], - constraints: Seq[TableConstraint]): CreateTable = { + constraints: Seq[TableConstraint], + isCreateTable: Boolean = true): LogicalPlan = { val tableId = UnresolvedIdentifier(Seq("t")) val tableSpec = UnresolvedTableSpec( Map.empty[String, String], Some("parquet"), OptionList(Seq.empty), None, None, None, None, false, constraints) - CreateTable(tableId, columns, Seq.empty, tableSpec, false) + if (isCreateTable) { + CreateTable(tableId, columns, Seq.empty, tableSpec, false) + } else { + ReplaceTable(tableId, columns, Seq.empty, tableSpec, false) + } } - protected def verifyConstraints(sql: String, constraints: Seq[TableConstraint]): Unit = { + protected def verifyConstraints( + sql: String, + constraints: Seq[TableConstraint], + isCreateTable: Boolean = true): Unit = { val parsed = 
parsePlan(sql) val columns = Seq( ColumnDefinition("a", IntegerType), ColumnDefinition("b", StringType) ) - val expected = createExpectedPlan(columns = columns, constraints = constraints) + val expected = createExpectedPlan( + columns = columns, constraints = constraints, isCreateTable = isCreateTable) comparePlans(parsed, expected) } } From 5cc47a31c6d68e6c385a4251c1c389147bf509d8 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 1 Apr 2025 14:28:56 -0700 Subject: [PATCH 61/65] handle Nondeterministic check --- .../org/apache/spark/SparkFunSuite.scala | 21 ++++++-- .../catalyst/analysis/ResolveTableSpec.scala | 10 +++- .../command/v2/CheckConstraintSuite.scala | 49 +++++++++++++++++-- 3 files changed, 71 insertions(+), 9 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index e38efc27b78f..ca8326918fec 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -371,10 +371,17 @@ abstract class SparkFunSuite "Invalid objectType of a query context Actual:" + actual.toString) assert(actual.objectName() === expected.objectName, "Invalid objectName of a query context. Actual:" + actual.toString) - assert(actual.startIndex() === expected.startIndex, - "Invalid startIndex of a query context. Actual:" + actual.toString) - assert(actual.stopIndex() === expected.stopIndex, - "Invalid stopIndex of a query context. Actual:" + actual.toString) + // If startIndex and stopIndex are -1, it means we simply want to check the + // fragment of the query context. This should be the case when the fragment is + // distinguished within the query text. + if (expected.startIndex != -1) { + assert(actual.startIndex() === expected.startIndex, + "Invalid startIndex of a query context. Actual:" + actual.toString) + } + if (expected.stopIndex != -1) { + assert(actual.stopIndex() === expected.stopIndex, + "Invalid stopIndex of a query context. Actual:" + actual.toString) + } assert(actual.fragment() === expected.fragment, "Invalid fragment of a query context. Actual:" + actual.toString) } else if (actual.contextType() == QueryContextType.DataFrame) { @@ -478,6 +485,12 @@ abstract class SparkFunSuite ExpectedContext("", "", start, stop, fragment) } + // Check the fragment only. 
This is only used when the fragment is distinguished within + // the query text + def apply(fragment: String): ExpectedContext = { + ExpectedContext("", "", -1, -1, fragment) + } + def apply( objectType: String, objectName: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala index 48bd9dbaa0be..65debcb10228 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala @@ -76,10 +76,16 @@ object ResolveTableSpec extends Rule[LogicalPlan] { val analyzed = DefaultColumnAnalyzer.execute(project) DefaultColumnAnalyzer.checkAnalysis0(analyzed) - val analyzedExpression = analyzed collectFirst { + val analyzedExpression = (analyzed collectFirst { case Project(Seq(Alias(e: Expression, _)), _) => e + }).get + if (!analyzedExpression.deterministic) { + analyzedExpression.failAnalysis( + errorClass = "INVALID_CHECK_CONSTRAINT.NONDETERMINISTIC", + messageParameters = Map.empty + ) } - c.withNewChildren(Seq(analyzedExpression.get)).asInstanceOf[CheckConstraint] + c.withNewChildren(Seq(analyzedExpression)).asInstanceOf[CheckConstraint] case other => other } analyzedExpressions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index 0d35fd4c48b4..818f4f42c65d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.execution.command.DDLCommandTestUtils class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLCommandTestUtils { override protected def command: String = "ALTER TABLE .. 
ADD CONSTRAINT" - test("Nondeterministic expression") { + test("Nondeterministic expression -- alter table") { withTable("t") { - sql("create table t(i double) using parquet") + sql("create table t(i double)") val query = """ |ALTER TABLE t ADD CONSTRAINT c1 CHECK (i > rand(0)) @@ -49,7 +49,31 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma } } - test("Expression referring a column of another table") { + test("Nondeterministic expression -- create table") { + Seq( + "create table t(i double check (i > rand(0)))", + "create table t(i double, constraint c1 check (i > rand(0)))", + "replace table t(i double check (i > rand(0)))", + "replace table t(i double, constraint c1 check (i > rand(0)))" + ).foreach { query => + withTable("t") { + val error = intercept[AnalysisException] { + sql(query) + } + checkError( + exception = error, + condition = "INVALID_CHECK_CONSTRAINT.NONDETERMINISTIC", + sqlState = "42621", + parameters = Map.empty, + context = ExpectedContext( + fragment = "i > rand(0)" + ) + ) + } + } + } + + test("Expression referring a column of another table -- alter table") { withTable("t", "t2") { sql("create table t(i double) using parquet") sql("create table t2(j string) using parquet") @@ -74,6 +98,25 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma } } + test("Expression referring a column of another table -- create and replace table") { + withTable("t", "t2") { + sql("create table t(i double) using parquet") + val query = "create table t2(j string check(t.i > 0)) using parquet" + val error = intercept[AnalysisException] { + sql(query) + } + checkError( + exception = error, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> "`t`.`i`", "proposal" -> "`j`"), + context = ExpectedContext( + fragment = "t.i" + ) + ) + } + } + private def getCheckConstraint(table: Table): Check = { assert(table.constraints.length == 1) assert(table.constraints.head.isInstanceOf[Check]) From 675b5735220230bc98de24d561b5e78217df2f9e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 1 Apr 2025 15:25:08 -0700 Subject: [PATCH 62/65] add more tests in CheckConstraintParseSuite --- .../command/CheckConstraintParseSuite.scala | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala index 6b55bf93b158..bcbae0d72118 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CheckConstraintParseSuite.scala @@ -59,7 +59,7 @@ class CheckConstraintParseSuite extends ConstraintParseSuiteBase { } } - test("Create table with invalid characteristic") { + test("Create table with invalid characteristic - table level") { invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2" val expectedContext = ExpectedContext( @@ -77,6 +77,52 @@ class CheckConstraintParseSuite extends ConstraintParseSuiteBase { } } + test("Create table with one check constraint - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0), b STRING) USING parquet" + verifyConstraints(sql, Seq(constraint1)) + } + + test("Create table with two check 
constraints - column level") { + val sql = "CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0), " + + "b STRING CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" + verifyConstraints(sql, Seq(constraint1, constraint2)) + } + + test("Create table with mixed column and table level check constraints") { + val sql = "CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0), b STRING, " + + "CONSTRAINT c2 CHECK (b = 'foo')) USING parquet" + verifyConstraints(sql, Seq(constraint1, constraint2)) + } + + test("Create table with valid characteristic - column level") { + validConstraintCharacteristics.foreach { + case (enforcedStr, relyStr, characteristic) => + val sql = s"CREATE TABLE t (a INT CONSTRAINT c1 CHECK (a > 0)" + + s" $enforcedStr $relyStr, b STRING) USING parquet" + val constraint = constraint1.withNameAndCharacteristic("c1", characteristic, null) + verifyConstraints(sql, Seq(constraint)) + } + } + + test("Create table with invalid characteristic - column level") { + invalidConstraintCharacteristics.foreach { case (characteristic1, characteristic2) => + val constraintStr = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2" + val sql = s"CREATE TABLE t (a INT $constraintStr, b STRING) USING parquet" + val expectedContext = ExpectedContext( + fragment = s"CONSTRAINT c1 CHECK (a > 0) $characteristic1 $characteristic2", + start = 22, + stop = 50 + characteristic1.length + characteristic2.length + ) + checkError( + exception = intercept[ParseException] { + parsePlan(sql) + }, + condition = "INVALID_CONSTRAINT_CHARACTERISTICS", + parameters = Map("characteristics" -> s"$characteristic1, $characteristic2"), + queryContext = Array(expectedContext)) + } + } + test("Create table with column 'constraint'") { val sql = "CREATE TABLE t (constraint STRING) USING parquet" val columns = Seq(ColumnDefinition("constraint", StringType)) From e51cbd0fbfdcc9ead45145d09c2990844d881933 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 1 Apr 2025 15:31:13 -0700 Subject: [PATCH 63/65] fix PrimaryKeyConstraintParseSuite --- .../command/PrimaryKeyConstraintParseSuite.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PrimaryKeyConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PrimaryKeyConstraintParseSuite.scala index 2d704cbe2659..e3b3f162fcae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PrimaryKeyConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PrimaryKeyConstraintParseSuite.scala @@ -27,7 +27,7 @@ class PrimaryKeyConstraintParseSuite extends ConstraintParseSuiteBase { test("Create table with primary key - table level") { val sql = "CREATE TABLE t (a INT, b STRING, PRIMARY KEY (a)) USING parquet" - val constraint = PrimaryKeyConstraint(columns = Seq("a")) + val constraint = PrimaryKeyConstraint(columns = Seq("a"), name = "t_pk") val constraints = Seq(constraint) verifyConstraints(sql, constraints) } @@ -44,18 +44,14 @@ class PrimaryKeyConstraintParseSuite extends ConstraintParseSuiteBase { test("Create table with composite primary key - table level") { val sql = "CREATE TABLE t (a INT, b STRING, PRIMARY KEY (a, b)) USING parquet" - val constraint = PrimaryKeyConstraint( - columns = Seq("a", "b") - ) + val constraint = PrimaryKeyConstraint(columns = Seq("a", "b"), name = "t_pk") val constraints = Seq(constraint) verifyConstraints(sql, constraints) } test("Create table with primary key - 
column level") { val sql = "CREATE TABLE t (a INT PRIMARY KEY, b STRING) USING parquet" - val constraint = PrimaryKeyConstraint( - columns = Seq("a") - ) + val constraint = PrimaryKeyConstraint(columns = Seq("a"), name = "t_pk") val constraints = Seq(constraint) verifyConstraints(sql, constraints) } @@ -86,7 +82,7 @@ class PrimaryKeyConstraintParseSuite extends ConstraintParseSuiteBase { } test("Add primary key constraint") { - Seq(("", null), ("CONSTRAINT pk1", "pk1")).foreach { case (constraintName, expectedName) => + Seq(("", "c_pk"), ("CONSTRAINT pk1", "pk1")).foreach { case (constraintName, expectedName) => val sql = s""" |ALTER TABLE a.b.c ADD $constraintName PRIMARY KEY (id, name) From aead1b2ac31417ee012480dc635e7359bd76f55b Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 1 Apr 2025 15:43:12 -0700 Subject: [PATCH 64/65] fix ForeignKeyConstraintParseSuite --- .../apache/spark/sql/catalyst/expressions/constraints.scala | 2 +- .../execution/command/ForeignKeyConstraintParseSuite.scala | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala index fd3cb1d08f7f..f4c171868ad3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala @@ -238,7 +238,7 @@ case class ForeignKeyConstraint( } override protected def generateConstraintName(tableName: String): String = - s"${tableName}_fk_${parentTableId.last}" + s"${tableName}_${parentTableId.last}_fk" override def defaultConstraintCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic(enforced = Some(false), rely = Some(false)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ForeignKeyConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ForeignKeyConstraintParseSuite.scala index 2c119610e008..1df5e7b9deb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ForeignKeyConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ForeignKeyConstraintParseSuite.scala @@ -27,6 +27,7 @@ class ForeignKeyConstraintParseSuite extends ConstraintParseSuiteBase { val sql = "CREATE TABLE t (a INT, b STRING," + " FOREIGN KEY (a) REFERENCES parent(id)) USING parquet" val constraint = ForeignKeyConstraint( + name = "t_parent_fk", childColumns = Seq("a"), parentTableId = Seq("parent"), parentColumns = Seq("id") @@ -51,6 +52,7 @@ class ForeignKeyConstraintParseSuite extends ConstraintParseSuiteBase { test("Create table with foreign key - column level") { val sql = "CREATE TABLE t (a INT REFERENCES parent(id), b STRING) USING parquet" val constraint = ForeignKeyConstraint( + name = "t_parent_fk", childColumns = Seq("a"), parentTableId = Seq("parent"), parentColumns = Seq("id") @@ -72,7 +74,9 @@ class ForeignKeyConstraintParseSuite extends ConstraintParseSuiteBase { } test("Add foreign key constraint") { - Seq(("", null), ("CONSTRAINT fk1", "fk1")).foreach { case (constraintName, expectedName) => + Seq( + ("", "orders_customers_fk"), + ("CONSTRAINT fk1", "fk1")).foreach { case (constraintName, expectedName) => val sql = s""" |ALTER TABLE orders ADD $constraintName FOREIGN KEY (customer_id) From dc24937559e1f3d2e899df06b5ca4d12eb6f9385 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 1 
Apr 2025 17:03:22 -0700 Subject: [PATCH 65/65] fix UniqueConstraintParseSuite --- .../command/UniqueConstraintParseSuite.scala | 133 +++++++++++++----- 1 file changed, 99 insertions(+), 34 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/UniqueConstraintParseSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/UniqueConstraintParseSuite.scala index 6c6cf851c95b..f33807e7cfa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/UniqueConstraintParseSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/UniqueConstraintParseSuite.scala @@ -20,39 +20,111 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedTable import org.apache.spark.sql.catalyst.expressions.UniqueConstraint import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.AddConstraint +import org.apache.spark.sql.catalyst.plans.logical.{AddConstraint, CreateTable, ReplaceTable, UnresolvedTableSpec} class UniqueConstraintParseSuite extends ConstraintParseSuiteBase { - test("Create table with unique constraint - table level") { - val sql = "CREATE TABLE t (a INT, b STRING, UNIQUE (a)) USING parquet" - val constraint = UniqueConstraint(columns = Seq("a")) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) - } - test("Create table with named unique constraint - table level") { - val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT uk1 UNIQUE (a)) USING parquet" - val constraint = UniqueConstraint( - columns = Seq("a"), - name = "uk1" - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) + test("Create table with unnamed unique constraint") { + Seq( + "CREATE TABLE t (a INT, b STRING, UNIQUE (a))", + "CREATE TABLE t (a INT UNIQUE, b STRING)" + ).foreach { sql => + val plan = parsePlan(sql) + plan match { + case c: CreateTable => + val tableSpec = c.tableSpec.asInstanceOf[UnresolvedTableSpec] + assert(tableSpec.constraints.size == 1) + assert(tableSpec.constraints.head.isInstanceOf[UniqueConstraint]) + assert(tableSpec.constraints.head.name.matches("t_uk_a_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected CreateTable, but got: $other") + } + } } test("Create table with composite unique constraint - table level") { - val sql = "CREATE TABLE t (a INT, b STRING, UNIQUE (a, b)) USING parquet" - val constraint = UniqueConstraint( - columns = Seq("a", "b") - ) - val constraints = Seq(constraint) - verifyConstraints(sql, constraints) + Seq( + "CREATE TABLE t (a INT, b STRING, UNIQUE (a, b))", + "CREATE TABLE t (a INT, b STRING, UNIQUE (b, a))" + ).foreach { sql => + val plan = parsePlan(sql) + plan match { + case c: CreateTable => + val tableSpec = c.tableSpec.asInstanceOf[UnresolvedTableSpec] + assert(tableSpec.constraints.size == 1) + assert(tableSpec.constraints.head.isInstanceOf[UniqueConstraint]) + assert(tableSpec.constraints.head.name.matches("t_uk_a_b_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected CreateTable, but got: $other") + } + } + } + + test("Create table with multiple unique constraints") { + Seq( + "CREATE TABLE t (a INT UNIQUE, b STRING, UNIQUE (b))", + "CREATE TABLE t (a INT, UNIQUE (a), b STRING UNIQUE)" + ).foreach { sql => + val plan = parsePlan(sql) + plan match { + case c: CreateTable => + val tableSpec = c.tableSpec.asInstanceOf[UnresolvedTableSpec] + assert(tableSpec.constraints.size == 2) + 
assert(tableSpec.constraints.head.isInstanceOf[UniqueConstraint]) + assert(tableSpec.constraints.head.name.matches("t_uk_a_[0-9a-zA-Z]{6}")) + assert(tableSpec.constraints.last.isInstanceOf[UniqueConstraint]) + assert(tableSpec.constraints.last.name.matches("t_uk_b_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected CreateTable, but got: $other") + } + } + } + + test("Replace table with unnamed unique constraint") { + Seq( + "REPLACE TABLE t (a INT, b STRING, UNIQUE (a))", + "REPLACE TABLE t (a INT UNIQUE, b STRING)" + ).foreach { sql => + val plan = parsePlan(sql) + plan match { + case c: ReplaceTable => + val tableSpec = c.tableSpec.asInstanceOf[UnresolvedTableSpec] + assert(tableSpec.constraints.size == 1) + assert(tableSpec.constraints.head.isInstanceOf[UniqueConstraint]) + assert(tableSpec.constraints.head.name.matches("t_uk_a_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected ReplaceTable, but got: $other") + } + } + } + + test("Add unnamed unique constraint") { + val sql = + """ + |ALTER TABLE a.b.c ADD UNIQUE (d) + |""".stripMargin + val plan = parsePlan(sql) + plan match { + case a: AddConstraint => + val table = a.table.asInstanceOf[UnresolvedTable] + assert(table.multipartIdentifier == Seq("a", "b", "c")) + assert(a.tableConstraint.isInstanceOf[UniqueConstraint]) + assert(a.tableConstraint.name.matches("c_uk_d_[0-9a-zA-Z]{6}")) + + case other => + fail(s"Expected AddConstraint, but got: $other") + } } - test("Create table with unique constraint - column level") { - val sql = "CREATE TABLE t (a INT UNIQUE, b STRING) USING parquet" + test("Create table with named unique constraint - table level") { + val sql = "CREATE TABLE t (a INT, b STRING, CONSTRAINT uk1 UNIQUE (a)) USING parquet" val constraint = UniqueConstraint( - columns = Seq("a") + columns = Seq("a"), + name = "uk1" ) val constraints = Seq(constraint) verifyConstraints(sql, constraints) @@ -68,16 +140,10 @@ class UniqueConstraintParseSuite extends ConstraintParseSuiteBase { verifyConstraints(sql, constraints) } - test("Create table with multiple unique constraints") { - val sql = "CREATE TABLE t (a INT UNIQUE, b STRING, UNIQUE (b)) USING parquet" - val constraint1 = UniqueConstraint(columns = Seq("a")) - val constraint2 = UniqueConstraint(columns = Seq("b")) - val constraints = Seq(constraint1, constraint2) - verifyConstraints(sql, constraints) - } - test("Add unique constraint") { - Seq(("", null), ("CONSTRAINT uk1", "uk1")).foreach { case (constraintName, expectedName) => + Seq( + ("consTrainT abcdEF", "abcdEF"), + ("CONSTRAINT uk1", "uk1")).foreach { case (constraintName, expectedName) => val sql = s""" |ALTER TABLE a.b.c ADD $constraintName UNIQUE (email, username) @@ -158,7 +224,6 @@ class UniqueConstraintParseSuite extends ConstraintParseSuiteBase { } } - test("ENFORCED is not supported for unique -- create table with unnamed constraint") { enforcedConstraintCharacteristics.foreach { case (c1, c2, _) => val characteristic = if (c2.isEmpty) {