Skip to content

Commit c6e9ab6

Browse files
committed
Add ALTER TABLE and address comments.
1 parent 8ee87dd commit c6e9ab6

File tree

4 files changed

+27
-20
lines changed

4 files changed

+27
-20
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -853,22 +853,17 @@ object DDLUtils {
853853
}
854854

855855
private[sql] def checkFieldNames(table: CatalogTable): Unit = {
856-
table.provider.get.toLowerCase(Locale.ROOT) match {
857-
case "hive" =>
858-
val serde = table.storage.serde
859-
if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
860-
OrcFileFormat.checkFieldNames(table.dataSchema)
861-
} else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde) {
862-
ParquetSchemaConverter.checkFieldNames(table.dataSchema)
863-
}
864-
865-
case "parquet" =>
866-
ParquetSchemaConverter.checkFieldNames(table.dataSchema)
867-
868-
case "orc" =>
869-
OrcFileFormat.checkFieldNames(table.dataSchema)
870-
871-
case _ =>
856+
val serde = table.storage.serde
857+
if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
858+
OrcFileFormat.checkFieldNames(table.dataSchema)
859+
} else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde) {
860+
ParquetSchemaConverter.checkFieldNames(table.dataSchema)
861+
} else {
862+
table.provider.get.toLowerCase(Locale.ROOT) match {
863+
case "parquet" => ParquetSchemaConverter.checkFieldNames(table.dataSchema)
864+
case "orc" => OrcFileFormat.checkFieldNames(table.dataSchema)
865+
case _ =>
866+
}
872867
}
873868
}
874869
}

sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ case class AlterTableAddColumnsCommand(
206206
reorderedSchema.map(_.name), "in the table definition of " + table.identifier,
207207
conf.caseSensitiveAnalysis)
208208

209+
val newDataSchema = StructType(catalogTable.dataSchema ++ columns)
210+
DDLUtils.checkFieldNames(catalogTable.copy(schema = newDataSchema))
211+
209212
catalog.alterTableSchema(
210213
table, catalogTable.schema.copy(fields = reorderedSchema.toArray))
211214

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ private[sql] object OrcFileFormat {
2929
} catch {
3030
case _: IllegalArgumentException =>
3131
throw new AnalysisException(
32-
s"""Attribute name "$name" contains invalid character(s).
32+
s"""Column name "$name" contains invalid character(s).
3333
|Please use alias to rename it.
3434
""".stripMargin.split("\n").mkString(" ").trim)
3535
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2001,9 +2001,9 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
20012001
}
20022002
}
20032003

2004-
test("SPARK-21912 Creating ORC/Parquet datasource table should check invalid column names") {
2005-
withTable("t21912") {
2006-
Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name =>
2004+
test("SPARK-21912 ORC/Parquet table should not create invalid column names") {
2005+
Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name =>
2006+
withTable("t21912") {
20072007
Seq("ORC", "PARQUET").foreach { source =>
20082008
val m = intercept[AnalysisException] {
20092009
sql(s"CREATE TABLE t21912(`col$name` INT) USING $source")
@@ -2022,6 +2022,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
20222022
assert(m3.contains(s"contains invalid character(s)"))
20232023
}
20242024
}
2025+
2026+
// TODO: After SPARK-21929, we need to check ORC, too.
2027+
Seq("PARQUET").foreach { source =>
2028+
sql(s"CREATE TABLE t21912(`col` INT) USING $source")
2029+
val m = intercept[AnalysisException] {
2030+
sql(s"ALTER TABLE t21912 ADD COLUMNS(`col$name` INT)")
2031+
}.getMessage
2032+
assert(m.contains(s"contains invalid character(s)"))
2033+
}
20252034
}
20262035
}
20272036
}

0 commit comments

Comments
 (0)