diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index dae160f1bbb1..7611e1c2e268 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -34,6 +34,9 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.execution.datasources.PartitioningUtils
+import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
+import org.apache.spark.sql.internal.HiveSerDe
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}

@@ -848,4 +851,22 @@ object DDLUtils {
       }
     }
   }
+
+  private[sql] def checkDataSchemaFieldNames(table: CatalogTable): Unit = {
+    table.provider.foreach {
+      _.toLowerCase(Locale.ROOT) match {
+        case HIVE_PROVIDER =>
+          val serde = table.storage.serde
+          if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
+            OrcFileFormat.checkFieldNames(table.dataSchema)
+          } else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde ||
+              serde == Some("parquet.hive.serde.ParquetHiveSerDe")) {
+            ParquetSchemaConverter.checkFieldNames(table.dataSchema)
+          }
+        case "parquet" => ParquetSchemaConverter.checkFieldNames(table.dataSchema)
+        case "orc" => OrcFileFormat.checkFieldNames(table.dataSchema)
+        case _ =>
+      }
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 694d517668a2..1dddc1ca324b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -201,13 +201,14 @@ case class AlterTableAddColumnsCommand(

     // make sure any partition columns are at the end of the fields
     val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema
+    val newSchema = catalogTable.schema.copy(fields = reorderedSchema.toArray)

     SchemaUtils.checkColumnNameDuplication(
       reorderedSchema.map(_.name), "in the table definition of " + table.identifier,
       conf.caseSensitiveAnalysis)
+    DDLUtils.checkDataSchemaFieldNames(catalogTable.copy(schema = newSchema))

-    catalog.alterTableSchema(
-      table, catalogTable.schema.copy(fields = reorderedSchema.toArray))
+    catalog.alterTableSchema(table, newSchema)

     Seq.empty[Row]
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 0deac1984bd6..5d6223dffd28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -130,10 +130,12 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast

   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) =>
+      DDLUtils.checkDataSchemaFieldNames(tableDesc)
       CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)

     case CreateTable(tableDesc, mode, Some(query))
         if query.resolved && DDLUtils.isDatasourceTable(tableDesc) =>
+      DDLUtils.checkDataSchemaFieldNames(tableDesc.copy(schema = query.schema))
       CreateDataSourceTableAsSelectCommand(tableDesc, mode, query)

     case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
new file mode 100644
index 000000000000..2eeb0065455f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.orc.TypeDescription
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.types.StructType
+
+private[sql] object OrcFileFormat {
+  private def checkFieldName(name: String): Unit = {
+    try {
+      TypeDescription.fromString(s"struct<$name:int>")
+    } catch {
+      case _: IllegalArgumentException =>
+        throw new AnalysisException(
+          s"""Column name "$name" contains invalid character(s).
+             |Please use alias to rename it.
+           """.stripMargin.split("\n").mkString(" ").trim)
+    }
+  }
+
+  def checkFieldNames(schema: StructType): StructType = {
+    schema.fieldNames.foreach(checkFieldName)
+    schema
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index 0b805e436288..b3781cfc4a60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -556,7 +556,7 @@
   }
 }

-private[parquet] object ParquetSchemaConverter {
+private[sql] object ParquetSchemaConverter {
   val SPARK_PARQUET_SCHEMA_NAME = "spark_schema"

   val EMPTY_MESSAGE: MessageType =
diff --git a/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql b/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql
index 1e02c2f045ea..521018e94e50 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql
@@ -2,9 +2,9 @@ CREATE DATABASE showdb;

 USE showdb;

-CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING parquet;
+CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING json;
 CREATE TABLE showcolumn2 (price int, qty int, year int, month int) USING parquet
 partitioned by (year, month);
-CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING parquet;
+CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING json;
 CREATE GLOBAL TEMP VIEW showColumn4 AS SELECT 1 as col1, 'abc' as `col 5`;

diff --git a/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out b/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out
index 05c3a083ee3b..71d6e120e894 100644
--- a/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out
@@ -19,7 +19,7 @@ struct<>


 -- !query 2
-CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING parquet
+CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING json
 -- !query 2 schema
 struct<>
 -- !query 2 output
@@ -35,7 +35,7 @@ struct<>


 -- !query 4
-CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING parquet
+CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING json
 -- !query 4 schema
 struct<>
 -- !query 4 output
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index ae1e7e72e8c3..47203a80c37b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -151,9 +151,11 @@
       InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists)

     case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) =>
+      DDLUtils.checkDataSchemaFieldNames(tableDesc)
       CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)

     case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) =>
+      DDLUtils.checkDataSchemaFieldNames(tableDesc)
       CreateHiveTableAsSelectCommand(tableDesc, query, mode)
   }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index d2a6ef7b2b37..85a6a77cedc4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -2000,4 +2000,38 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       assert(setOfPath.size() == pathSizeToDeleteOnExit)
     }
   }
+
+  test("SPARK-21912 ORC/Parquet table should not create invalid column names") {
+    Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name =>
+      withTable("t21912") {
+        Seq("ORC", "PARQUET").foreach { source =>
+          val m = intercept[AnalysisException] {
+            sql(s"CREATE TABLE t21912(`col$name` INT) USING $source")
+          }.getMessage
+          assert(m.contains(s"contains invalid character(s)"))
+
+          val m2 = intercept[AnalysisException] {
+            sql(s"CREATE TABLE t21912 USING $source AS SELECT 1 `col$name`")
+          }.getMessage
+          assert(m2.contains(s"contains invalid character(s)"))
+
+          withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") {
+            val m3 = intercept[AnalysisException] {
+              sql(s"CREATE TABLE t21912(`col$name` INT) USING hive OPTIONS (fileFormat '$source')")
+            }.getMessage
+            assert(m3.contains(s"contains invalid character(s)"))
+          }
+        }
+
+        // TODO: After SPARK-21929, we need to check ORC, too.
+        Seq("PARQUET").foreach { source =>
+          sql(s"CREATE TABLE t21912(`col` INT) USING $source")
+          val m = intercept[AnalysisException] {
+            sql(s"ALTER TABLE t21912 ADD COLUMNS(`col$name` INT)")
+          }.getMessage
+          assert(m.contains(s"contains invalid character(s)"))
+        }
+      }
+    }
+  }
 }
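
Reviewer note, not part of the patch: a minimal sketch of the user-visible behavior this change introduces, assuming a Hive-enabled SparkSession named spark (as in spark-shell). Both the plain CREATE TABLE path and the CTAS path now run DDLUtils.checkDataSchemaFieldNames during analysis, so a column name containing a character that Parquet or ORC cannot encode is rejected up front instead of failing later inside the writer. The exact wording differs between the ORC and Parquet checks, but both messages contain "contains invalid character(s)", which is what the new test asserts.

// Sketch only; table names are illustrative.
import org.apache.spark.sql.AnalysisException

// Rejected at analysis time: " " is not a legal character in a Parquet field name.
try {
  spark.sql("CREATE TABLE bad_parquet(`col name` INT) USING parquet")
} catch {
  case e: AnalysisException => println(e.getMessage)
}

// The CTAS path is covered as well, because DataSourceAnalysis validates
// tableDesc.copy(schema = query.schema) before planning the command.
try {
  spark.sql("CREATE TABLE bad_orc USING orc AS SELECT 1 AS `col name`")
} catch {
  case e: AnalysisException => println(e.getMessage)
}

// Renaming the offending column (the "use alias" suggestion in the message) succeeds.
spark.sql("CREATE TABLE ok_orc USING orc AS SELECT 1 AS col_name")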