diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 336d3d65d0dd..24fff3ad563a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View}
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.StructType

 object SessionCatalog {
   val DEFAULT_DATABASE = "default"
@@ -188,19 +188,6 @@ class SessionCatalog(
     }
   }

-  private def checkDuplication(fields: Seq[StructField]): Unit = {
-    val columnNames = if (conf.caseSensitiveAnalysis) {
-      fields.map(_.name)
-    } else {
-      fields.map(_.name.toLowerCase)
-    }
-    if (columnNames.distinct.length != columnNames.length) {
-      val duplicateColumns = columnNames.groupBy(identity).collect {
-        case (x, ys) if ys.length > 1 => x
-      }
-      throw new AnalysisException(s"Found duplicate column(s): ${duplicateColumns.mkString(", ")}")
-    }
-  }
   // ----------------------------------------------------------------------------
   // Databases
   // ----------------------------------------------------------------------------
@@ -353,7 +340,6 @@ class SessionCatalog(
     val tableIdentifier = TableIdentifier(table, Some(db))
     requireDbExists(db)
     requireTableExists(tableIdentifier)
-    checkDuplication(newSchema)

     val catalogTable = externalCatalog.getTable(db, table)
     val oldSchema = catalogTable.schema
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
index e881685ce626..41ca270095ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
@@ -17,7 +17,9 @@

 package org.apache.spark.sql.util

-import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.types.StructType


 /**
@@ -25,29 +27,63 @@ import org.apache.spark.internal.Logging
  *
  * TODO: Merge this file with [[org.apache.spark.ml.util.SchemaUtils]].
  */
-private[spark] object SchemaUtils extends Logging {
+private[spark] object SchemaUtils {

   /**
-   * Checks if input column names have duplicate identifiers. Prints a warning message if
+   * Checks if an input schema has duplicate column names. This throws an exception if the
+   * duplication exists.
+   *
+   * @param schema schema to check
+   * @param colType column type name, used in an exception message
+   * @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not
+   */
+  def checkSchemaColumnNameDuplication(
+      schema: StructType, colType: String, caseSensitiveAnalysis: Boolean = false): Unit = {
+    checkColumnNameDuplication(schema.map(_.name), colType, caseSensitiveAnalysis)
+  }
+
+  // Returns true if a given resolver is case-sensitive
+  private def isCaseSensitiveAnalysis(resolver: Resolver): Boolean = {
+    if (resolver == caseSensitiveResolution) {
+      true
+    } else if (resolver == caseInsensitiveResolution) {
+      false
+    } else {
+      sys.error("A resolver to check if two identifiers are equal must be " +
+        "`caseSensitiveResolution` or `caseInsensitiveResolution` in o.a.s.sql.catalyst.")
+    }
+  }
+
+  /**
+   * Checks if input column names have duplicate identifiers. This throws an exception if
+   * the duplication exists.
    *
    * @param columnNames column names to check
-   * @param colType column type name, used in a warning message
+   * @param colType column type name, used in an exception message
+   * @param resolver resolver used to determine if two identifiers are equal
+   */
+  def checkColumnNameDuplication(
+      columnNames: Seq[String], colType: String, resolver: Resolver): Unit = {
+    checkColumnNameDuplication(columnNames, colType, isCaseSensitiveAnalysis(resolver))
+  }
+
+  /**
+   * Checks if input column names have duplicate identifiers. This throws an exception if
+   * the duplication exists.
+   *
+   * @param columnNames column names to check
+   * @param colType column type name, used in an exception message
    * @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not
    */
   def checkColumnNameDuplication(
       columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit = {
-    val names = if (caseSensitiveAnalysis) {
-      columnNames
-    } else {
-      columnNames.map(_.toLowerCase)
-    }
+    val names = if (caseSensitiveAnalysis) columnNames else columnNames.map(_.toLowerCase)
     if (names.distinct.length != names.length) {
       val duplicateColumns = names.groupBy(identity).collect {
         case (x, ys) if ys.length > 1 => s"`$x`"
       }
-      logWarning(s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}. " +
-        "You might need to assign different column names.")
+      throw new AnalysisException(
+        s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}")
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala
new file mode 100644
index 000000000000..a25be2fe61db
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.types.StructType
+
+class SchemaUtilsSuite extends SparkFunSuite {
+
+  private def resolver(caseSensitiveAnalysis: Boolean): Resolver = {
+    if (caseSensitiveAnalysis) {
+      caseSensitiveResolution
+    } else {
+      caseInsensitiveResolution
+    }
+  }
+
+  Seq((true, ("a", "a"), ("b", "b")), (false, ("a", "A"), ("b", "B"))).foreach {
+    case (caseSensitive, (a0, a1), (b0, b1)) =>
+
+      val testType = if (caseSensitive) "case-sensitive" else "case-insensitive"
+      test(s"Check column name duplication in $testType cases") {
+        def checkExceptionCases(schemaStr: String, duplicatedColumns: Seq[String]): Unit = {
+          val expectedErrorMsg = "Found duplicate column(s) in SchemaUtilsSuite: " +
+            duplicatedColumns.map(c => s"`${c.toLowerCase}`").mkString(", ")
+          val schema = StructType.fromDDL(schemaStr)
+          var msg = intercept[AnalysisException] {
+            SchemaUtils.checkSchemaColumnNameDuplication(
+              schema, "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive)
+          }.getMessage
+          assert(msg.contains(expectedErrorMsg))
+          msg = intercept[AnalysisException] {
+            SchemaUtils.checkColumnNameDuplication(
+              schema.map(_.name), "in SchemaUtilsSuite", resolver(caseSensitive))
+          }.getMessage
+          assert(msg.contains(expectedErrorMsg))
+          msg = intercept[AnalysisException] {
+            SchemaUtils.checkColumnNameDuplication(
+              schema.map(_.name), "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive)
+          }.getMessage
+          assert(msg.contains(expectedErrorMsg))
+        }
+
+        checkExceptionCases(s"$a0 INT, b INT, $a1 INT", a0 :: Nil)
+        checkExceptionCases(s"$a0 INT, b INT, $a1 INT, $a0 INT", a0 :: Nil)
+        checkExceptionCases(s"$a0 INT, $b0 INT, $a1 INT, $a0 INT, $b1 INT", b0 :: a0 :: Nil)
+      }
+  }
+
+  test("Check no exception thrown for valid schemas") {
+    def checkNoExceptionCases(schemaStr: String, caseSensitive: Boolean): Unit = {
+      val schema = StructType.fromDDL(schemaStr)
+      SchemaUtils.checkSchemaColumnNameDuplication(
+        schema, "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive)
+      SchemaUtils.checkColumnNameDuplication(
+        schema.map(_.name), "in SchemaUtilsSuite", resolver(caseSensitive))
+      SchemaUtils.checkColumnNameDuplication(
+        schema.map(_.name), "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive)
+    }
+
+    checkNoExceptionCases("a INT, b INT, c INT", caseSensitive = true)
+    checkNoExceptionCases("Aa INT, b INT, aA INT", caseSensitive = true)
+
+    checkNoExceptionCases("a INT, b INT, c INT", caseSensitive = false)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 729bd39d821c..04b2534ca5eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.command

 import java.net.URI

-import org.apache.hadoop.fs.Path
-
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 8ded1060f7bf..fa50d1272241 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.command
 import java.io.File
 import java.net.URI
 import java.nio.file.FileSystems
-import java.util.Date

 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 import scala.util.Try

-import org.apache.commons.lang3.StringEscapeUtils
 import org.apache.hadoop.fs.Path

 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
@@ -42,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.util.Utils

 /**
@@ -202,6 +201,11 @@ case class AlterTableAddColumnsCommand(

     // make sure any partition columns are at the end of the fields
     val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema
+
+    SchemaUtils.checkColumnNameDuplication(
+      reorderedSchema.map(_.name), "in the table definition of " + table.identifier,
+      conf.caseSensitiveAnalysis)
+
     catalog.alterTableSchema(
       table, catalogTable.schema.copy(fields = reorderedSchema.toArray))

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index a6d56ca91a3e..ffdfd527fa70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View}
 import org.apache.spark.sql.types.MetadataBuilder
+import org.apache.spark.sql.util.SchemaUtils


 /**
@@ -355,15 +356,15 @@ object ViewHelper {
       properties: Map[String, String],
       session: SparkSession,
       analyzedPlan: LogicalPlan): Map[String, String] = {
+    val queryOutput = analyzedPlan.schema.fieldNames
+
     // Generate the query column names, throw an AnalysisException if there exists duplicate column
     // names.
-    val queryOutput = analyzedPlan.schema.fieldNames
-    assert(queryOutput.distinct.size == queryOutput.size,
-      s"The view output ${queryOutput.mkString("(", ",", ")")} contains duplicate column name.")
+    SchemaUtils.checkColumnNameDuplication(
+      queryOutput, "in the view definition", session.sessionState.conf.resolver)

     // Generate the view default database name.
     val viewDefaultDatabase = session.sessionState.catalog.getCurrentDatabase
-
     removeQueryColumnNames(properties) ++
       generateViewDefaultDatabase(viewDefaultDatabase) ++
       generateQueryColumnNames(queryOutput)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 75e530607570..d36a04f1fff8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -87,6 +87,14 @@ case class DataSource(
   lazy val providingClass: Class[_] = DataSource.lookupDataSource(className)
   lazy val sourceInfo: SourceInfo = sourceSchema()
   private val caseInsensitiveOptions = CaseInsensitiveMap(options)
+  private val equality = sparkSession.sessionState.conf.resolver
+
+  bucketSpec.map { bucket =>
+    SchemaUtils.checkColumnNameDuplication(
+      bucket.bucketColumnNames, "in the bucket definition", equality)
+    SchemaUtils.checkColumnNameDuplication(
+      bucket.sortColumnNames, "in the sort definition", equality)
+  }

   /**
    * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer
@@ -132,7 +140,6 @@ case class DataSource(
       // Try to infer partitioning, because no DataSource in the read path provides the partitioning
       // columns properly unless it is a Hive DataSource
       val resolved = tempFileIndex.partitionSchema.map { partitionField =>
-        val equality = sparkSession.sessionState.conf.resolver
         // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
         userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
           partitionField)
@@ -146,7 +153,6 @@ case class DataSource(
         inferredPartitions
       } else {
         val partitionFields = partitionColumns.map { partitionColumn =>
-          val equality = sparkSession.sessionState.conf.resolver
           userSpecifiedSchema.flatMap(_.find(c => equality(c.name, partitionColumn))).orElse {
             val inferredPartitions = tempFileIndex.partitionSchema
             val inferredOpt = inferredPartitions.find(p => equality(p.name, partitionColumn))
@@ -172,7 +178,6 @@ case class DataSource(
       }

       val dataSchema = userSpecifiedSchema.map { schema =>
-        val equality = sparkSession.sessionState.conf.resolver
         StructType(schema.filterNot(f => partitionSchema.exists(p => equality(p.name, f.name))))
       }.orElse {
         format.inferSchema(
@@ -184,9 +189,18 @@ case class DataSource(
           s"Unable to infer schema for $format. It must be specified manually.")
       }

-      SchemaUtils.checkColumnNameDuplication(
-        (dataSchema ++ partitionSchema).map(_.name), "in the data schema and the partition schema",
-        sparkSession.sessionState.conf.caseSensitiveAnalysis)
+      // We just print a warning message if the data schema and partition schema have duplicate
+      // columns. This is because we allowed users to do so in previous Spark releases and
+      // there are existing tests for these cases (e.g., `ParquetHadoopFsRelationSuite`).
+      // See SPARK-18108 and SPARK-21144 for related discussions.
+      try {
+        SchemaUtils.checkColumnNameDuplication(
+          (dataSchema ++ partitionSchema).map(_.name),
+          "in the data schema and the partition schema",
+          equality)
+      } catch {
+        case e: AnalysisException => logWarning(e.getMessage)
+      }

       (dataSchema, partitionSchema)
     }
@@ -391,6 +405,23 @@ case class DataSource(
         s"$className is not a valid Spark SQL Data Source.")
     }

+    relation match {
+      case hs: HadoopFsRelation =>
+        SchemaUtils.checkColumnNameDuplication(
+          hs.dataSchema.map(_.name),
+          "in the data schema",
+          equality)
+        SchemaUtils.checkColumnNameDuplication(
+          hs.partitionSchema.map(_.name),
+          "in the partition schema",
+          equality)
+      case _ =>
+        SchemaUtils.checkColumnNameDuplication(
+          relation.schema.map(_.name),
+          "in the data schema",
+          equality)
+    }
+
     relation
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 0031567d3d28..c1bcfb861078 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -21,7 +21,6 @@ import java.io.IOException

 import org.apache.hadoop.fs.{FileSystem, Path}

-import org.apache.spark.SparkContext
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition}
@@ -30,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command._
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.util.SchemaUtils

 /**
  * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending.
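
The checks above all funnel into the two SchemaUtils entry points added in sql/catalyst. A minimal sketch of how they behave, mirroring what the new SchemaUtilsSuite does (the schema and the colType strings below are made up for illustration, and SchemaUtils is private[spark], so this only compiles inside the Spark source tree):

    import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
    import org.apache.spark.sql.types.StructType
    import org.apache.spark.sql.util.SchemaUtils

    val schema = StructType.fromDDL("a INT, b STRING, A DOUBLE")

    // Boolean flavour: case-insensitive by default, so "a" and "A" collide here.
    SchemaUtils.checkSchemaColumnNameDuplication(schema, "in the data schema")

    // Resolver flavour: the resolver decides the case sensitivity instead.
    SchemaUtils.checkColumnNameDuplication(
      schema.map(_.name), "in the data schema", caseInsensitiveResolution)

    // Both calls throw:
    //   org.apache.spark.sql.AnalysisException: Found duplicate column(s) in the data schema: `a`
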
@@ -64,13 +63,10 @@ case class InsertIntoHadoopFsRelationCommand(
     assert(children.length == 1)

     // Most formats don't do well with duplicate columns, so lets not allow that
-    if (query.schema.fieldNames.length != query.schema.fieldNames.distinct.length) {
-      val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect {
-        case (x, ys) if ys.length > 1 => "\"" + x + "\""
-      }.mkString(", ")
-      throw new AnalysisException(s"Duplicate column(s): $duplicateColumns found, " +
-        "cannot save to file.")
-    }
+    SchemaUtils.checkSchemaColumnNameDuplication(
+      query.schema,
+      s"when inserting into $outputPath",
+      sparkSession.sessionState.conf.caseSensitiveAnalysis)

     val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options)
     val fs = outputPath.getFileSystem(hadoopConf)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index f61c673baaa5..92358da6d6c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SchemaUtils

 // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling.

@@ -301,13 +302,8 @@ object PartitioningUtils {
       normalizedKey -> value
     }

-    if (normalizedPartSpec.map(_._1).distinct.length != normalizedPartSpec.length) {
-      val duplicateColumns = normalizedPartSpec.map(_._1).groupBy(identity).collect {
-        case (x, ys) if ys.length > 1 => x
-      }
-      throw new AnalysisException(s"Found duplicated columns in partition specification: " +
-        duplicateColumns.mkString(", "))
-    }
+    SchemaUtils.checkColumnNameDuplication(
+      normalizedPartSpec.map(_._1), "in the partition schema", resolver)

     normalizedPartSpec.toMap
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 55b2539c1338..bbe9024f13a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.NextIterator

@@ -749,14 +750,8 @@ object JdbcUtils extends Logging {
     val nameEquality = df.sparkSession.sessionState.conf.resolver

     // checks duplicate columns in the user specified column types.
-    userSchema.fieldNames.foreach { col =>
-      val duplicatesCols = userSchema.fieldNames.filter(nameEquality(_, col))
-      if (duplicatesCols.size >= 2) {
-        throw new AnalysisException(
-          "Found duplicate column(s) in createTableColumnTypes option value: " +
-            duplicatesCols.mkString(", "))
-      }
-    }
+    SchemaUtils.checkColumnNameDuplication(
+      userSchema.map(_.name), "in the createTableColumnTypes option value", nameEquality)

     // checks if user specified column names exist in the DataFrame schema
     userSchema.fieldNames.foreach { col =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 5a92a71d19e7..8b7c2709afde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -59,9 +59,7 @@ abstract class JsonDataSource extends Serializable {
       inputPaths: Seq[FileStatus],
       parsedOptions: JSONOptions): Option[StructType] = {
     if (inputPaths.nonEmpty) {
-      val jsonSchema = infer(sparkSession, inputPaths, parsedOptions)
-      checkConstraints(jsonSchema)
-      Some(jsonSchema)
+      Some(infer(sparkSession, inputPaths, parsedOptions))
     } else {
       None
     }
@@ -71,17 +69,6 @@ abstract class JsonDataSource extends Serializable {
     sparkSession: SparkSession,
     inputPaths: Seq[FileStatus],
     parsedOptions: JSONOptions): StructType
-
-  /** Constraints to be imposed on schema to be stored. */
-  private def checkConstraints(schema: StructType): Unit = {
-    if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
-      val duplicateColumns = schema.fieldNames.groupBy(identity).collect {
-        case (x, ys) if ys.length > 1 => "\"" + x + "\""
-      }.mkString(", ")
-      throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " +
-        s"cannot save to JSON format")
-    }
-  }
 }

 object JsonDataSource {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 3f4a78580f1e..41d40aa926fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.InsertableRelation
 import org.apache.spark.sql.types.{AtomicType, StructType}
+import org.apache.spark.sql.util.SchemaUtils

 /**
  * Try to replaces [[UnresolvedRelation]]s if the plan is for direct query on files.
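
The PreprocessTableCreation changes that follow surface the same checks at DDL analysis time. A rough sketch of what now fails in a spark-shell session with spark.sql.caseSensitive left at its default of false; the table and column names are made up, and the messages in the comments show only the stable prefix asserted by the new DDLSuite tests:

    // SQLConf.resolver is caseInsensitiveResolution exactly when caseSensitiveAnalysis is false,
    // which is why the boolean and resolver overloads are interchangeable in these rules.
    spark.sql("CREATE TABLE t(aA INT) USING parquet PARTITIONED BY (aA, Aa)")
    // => AnalysisException: Found duplicate column(s) in the partition schema: ...

    spark.sql("CREATE TABLE t(aA INT) USING parquet CLUSTERED BY (aA, Aa) INTO 2 BUCKETS")
    // => AnalysisException: Found duplicate column(s) in the bucket definition: ...
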
@@ -222,12 +223,10 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
   }

   private def normalizeCatalogTable(schema: StructType, table: CatalogTable): CatalogTable = {
-    val columnNames = if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
-      schema.map(_.name)
-    } else {
-      schema.map(_.name.toLowerCase)
-    }
-    checkDuplication(columnNames, "table definition of " + table.identifier)
+    SchemaUtils.checkSchemaColumnNameDuplication(
+      schema,
+      "in the table definition of " + table.identifier,
+      sparkSession.sessionState.conf.caseSensitiveAnalysis)

     val normalizedPartCols = normalizePartitionColumns(schema, table)
     val normalizedBucketSpec = normalizeBucketSpec(schema, table)
@@ -253,7 +252,10 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
       partCols = table.partitionColumnNames,
       resolver = sparkSession.sessionState.conf.resolver)

-    checkDuplication(normalizedPartitionCols, "partition")
+    SchemaUtils.checkColumnNameDuplication(
+      normalizedPartitionCols,
+      "in the partition schema",
+      sparkSession.sessionState.conf.resolver)

     if (schema.nonEmpty && normalizedPartitionCols.length == schema.length) {
       if (DDLUtils.isHiveTable(table)) {
@@ -283,8 +285,15 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
       tableCols = schema.map(_.name),
       bucketSpec = bucketSpec,
       resolver = sparkSession.sessionState.conf.resolver)
-    checkDuplication(normalizedBucketSpec.bucketColumnNames, "bucket")
-    checkDuplication(normalizedBucketSpec.sortColumnNames, "sort")
+
+    SchemaUtils.checkColumnNameDuplication(
+      normalizedBucketSpec.bucketColumnNames,
+      "in the bucket definition",
+      sparkSession.sessionState.conf.resolver)
+    SchemaUtils.checkColumnNameDuplication(
+      normalizedBucketSpec.sortColumnNames,
+      "in the sort definition",
+      sparkSession.sessionState.conf.resolver)

     normalizedBucketSpec.sortColumnNames.map(schema(_)).map(_.dataType).foreach {
       case dt if RowOrdering.isOrderable(dt) => // OK
@@ -297,15 +306,6 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
     }
   }

-  private def checkDuplication(colNames: Seq[String], colType: String): Unit = {
-    if (colNames.distinct.length != colNames.length) {
-      val duplicateColumns = colNames.groupBy(identity).collect {
-        case (x, ys) if ys.length > 1 => x
-      }
-      failAnalysis(s"Found duplicate column(s) in $colType: ${duplicateColumns.mkString(", ")}")
-    }
-  }
-
   private def failAnalysis(msg: String) = throw new AnalysisException(msg)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 9ea9951c24ef..3c08b40a2652 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1123,7 +1123,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1")
         .write.format("parquet").save("temp")
     }
-    assert(e.getMessage.contains("Duplicate column(s)"))
+    assert(e.getMessage.contains("Found duplicate column(s) when inserting into"))
     assert(e.getMessage.contains("column1"))
     assert(!e.getMessage.contains("column2"))

@@ -1133,7 +1133,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
         .toDF("column1", "column2", "column3", "column1", "column3")
         .write.format("json").save("temp")
     }
-    assert(f.getMessage.contains("Duplicate column(s)"))
+    assert(f.getMessage.contains("Found duplicate column(s) when inserting into"))
     assert(f.getMessage.contains("column1"))
     assert(f.getMessage.contains("column3"))
     assert(!f.getMessage.contains("column2"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 5c40d8bb4b1e..5c0a6aa724bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -436,16 +436,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
   }

   test("create table - duplicate column names in the table definition") {
-    val e = intercept[AnalysisException] {
-      sql("CREATE TABLE tbl(a int, a string) USING json")
-    }
-    assert(e.message == "Found duplicate column(s) in table definition of `tbl`: a")
-
-    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
-      val e2 = intercept[AnalysisException] {
-        sql("CREATE TABLE tbl(a int, A string) USING json")
+    Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        val errMsg = intercept[AnalysisException] {
+          sql(s"CREATE TABLE t($c0 INT, $c1 INT) USING parquet")
+        }.getMessage
+        assert(errMsg.contains("Found duplicate column(s) in the table definition of `t`"))
       }
-      assert(e2.message == "Found duplicate column(s) in table definition of `tbl`: a")
     }
   }

@@ -466,17 +463,33 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
   }

   test("create table - column repeated in partition columns") {
-    val e = intercept[AnalysisException] {
-      sql("CREATE TABLE tbl(a int) USING json PARTITIONED BY (a, a)")
+    Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        val errMsg = intercept[AnalysisException] {
+          sql(s"CREATE TABLE t($c0 INT) USING parquet PARTITIONED BY ($c0, $c1)")
+        }.getMessage
+        assert(errMsg.contains("Found duplicate column(s) in the partition schema"))
+      }
     }
-    assert(e.message == "Found duplicate column(s) in partition: a")
   }

-  test("create table - column repeated in bucket columns") {
-    val e = intercept[AnalysisException] {
-      sql("CREATE TABLE tbl(a int) USING json CLUSTERED BY (a, a) INTO 4 BUCKETS")
+  test("create table - column repeated in bucket/sort columns") {
+    Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        var errMsg = intercept[AnalysisException] {
+          sql(s"CREATE TABLE t($c0 INT) USING parquet CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS")
+        }.getMessage
+        assert(errMsg.contains("Found duplicate column(s) in the bucket definition"))
+
+        errMsg = intercept[AnalysisException] {
+          sql(s"""
+            |CREATE TABLE t($c0 INT, col INT) USING parquet CLUSTERED BY (col)
+            | SORTED BY ($c0, $c1) INTO 2 BUCKETS
+          """.stripMargin)
+        }.getMessage
+        assert(errMsg.contains("Found duplicate column(s) in the sort definition"))
+      }
     }
-    assert(e.message == "Found duplicate column(s) in bucket: a")
   }

   test("Refresh table after changing the data source table partitioning") {
@@ -528,6 +541,17 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
     }
   }

+  test("create view - duplicate column names in the view definition") {
+    Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        val errMsg = intercept[AnalysisException] {
+          sql(s"CREATE VIEW t AS SELECT * FROM VALUES (1, 1) AS t($c0, $c1)")
+        }.getMessage
+        assert(errMsg.contains("Found duplicate column(s) in the view definition"))
+      }
+    }
+  }
+
   test("Alter/Describe Database") {
     val catalog = spark.sessionState.catalog
     val databaseNames = Seq("db1", "`database`")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 92f50a095f19..2334d5ae32dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.jdbc

-import java.sql.{Date, DriverManager, Timestamp}
+import java.sql.DriverManager
 import java.util.Properties

 import scala.collection.JavaConverters.propertiesAsScalaMapConverter
@@ -479,7 +479,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
           .jdbc(url1, "TEST.USERDBTYPETEST", properties)
       }.getMessage()
       assert(msg.contains(
-        "Found duplicate column(s) in createTableColumnTypes option value: name, NaMe"))
+        "Found duplicate column(s) in the createTableColumnTypes option value: `name`"))
     }
   }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
index 0f97fd78d2ff..308c5079c44b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
@@ -21,11 +21,12 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.test.SharedSQLContext

-class ResolvedDataSourceSuite extends SparkFunSuite {
+class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext {
   private def getProvidingClass(name: String): Class[_] =
     DataSource(
-      sparkSession = null,
+      sparkSession = spark,
       className = name,
       options = Map(DateTimeUtils.TIMEZONE_OPTION -> DateTimeUtils.defaultTimeZone().getID)
     ).providingClass
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index bb6a27803bb2..6676099d426b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils

@@ -352,4 +353,40 @@ class FileStreamSinkSuite extends StreamTest {
     assertAncestorIsNotMetadataDirectory(s"/a/b/c")
     assertAncestorIsNotMetadataDirectory(s"/a/b/c/${FileStreamSink.metadataDir}extra")
   }
+
+  test("SPARK-20460 Check name duplication in schema") {
+    Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        val inputData = MemoryStream[(Int, Int)]
+        val df = inputData.toDF()
+
+        val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath
+        val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath
+
+        var query: StreamingQuery = null
+        try {
+          query =
+            df.writeStream
+              .option("checkpointLocation", checkpointDir)
+              .format("json")
+              .start(outputDir)
+
+          inputData.addData((1, 1))
+
+          failAfter(streamingTimeout) {
+            query.processAllAvailable()
+          }
+        } finally {
+          if (query != null) {
+            query.stop()
+          }
+        }
+
+        val errorMsg = intercept[AnalysisException] {
+          spark.read.schema(s"$c0 INT, $c1 INT").json(outputDir).as[(Int, Int)]
+        }.getMessage
+        assert(errorMsg.contains("Found duplicate column(s) in the data schema: "))
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 306aecb5bbc8..569bac156b53 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
 import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -687,4 +688,91 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     testRead(spark.read.schema(userSchemaString).text(dir, dir), data ++ data, userSchema)
     testRead(spark.read.schema(userSchemaString).text(Seq(dir, dir): _*), data ++ data, userSchema)
   }
+
+  test("SPARK-20460 Check name duplication in buckets") {
+    Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        var errorMsg = intercept[AnalysisException] {
+          Seq((1, 1)).toDF("col", c0).write.bucketBy(2, c0, c1).saveAsTable("t")
+        }.getMessage
+        assert(errorMsg.contains("Found duplicate column(s) in the bucket definition"))
+
+        errorMsg = intercept[AnalysisException] {
+          Seq((1, 1)).toDF("col", c0).write.bucketBy(2, "col").sortBy(c0, c1).saveAsTable("t")
+        }.getMessage
+        assert(errorMsg.contains("Found duplicate column(s) in the sort definition"))
+      }
+    }
+  }
+
+  test("SPARK-20460 Check name duplication in schema") {
+    def checkWriteDataColumnDuplication(
+        format: String, colName0: String, colName1: String, tempDir: File): Unit = {
+      val errorMsg = intercept[AnalysisException] {
+        Seq((1, 1)).toDF(colName0, colName1).write.format(format).mode("overwrite")
+          .save(tempDir.getAbsolutePath)
+      }.getMessage
+      assert(errorMsg.contains("Found duplicate column(s) when inserting into"))
+    }
+
+    def checkReadUserSpecifiedDataColumnDuplication(
+        df: DataFrame, format: String, colName0: String, colName1: String, tempDir: File): Unit = {
+      val testDir = Utils.createTempDir(tempDir.getAbsolutePath)
+      df.write.format(format).mode("overwrite").save(testDir.getAbsolutePath)
+      val errorMsg = intercept[AnalysisException] {
+        spark.read.format(format).schema(s"$colName0 INT, $colName1 INT")
+          .load(testDir.getAbsolutePath)
+      }.getMessage
+      assert(errorMsg.contains("Found duplicate column(s) in the data schema:"))
+    }
+
+    def checkReadPartitionColumnDuplication(
+        format: String, colName0: String, colName1: String, tempDir: File): Unit = {
+      val testDir = Utils.createTempDir(tempDir.getAbsolutePath)
+      Seq(1).toDF("col").write.format(format).mode("overwrite")
+        .save(s"${testDir.getAbsolutePath}/$colName0=1/$colName1=1")
+      val errorMsg = intercept[AnalysisException] {
+        spark.read.format(format).load(testDir.getAbsolutePath)
+      }.getMessage
+      assert(errorMsg.contains("Found duplicate column(s) in the partition schema:"))
+    }
+
+    Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        withTempDir { src =>
+          // Check CSV format
+          checkWriteDataColumnDuplication("csv", c0, c1, src)
+          checkReadUserSpecifiedDataColumnDuplication(
+            Seq((1, 1)).toDF("c0", "c1"), "csv", c0, c1, src)
+          // If `inferSchema` is true, the CSV format is duplicate-safe (See SPARK-16896)
+          var testDir = Utils.createTempDir(src.getAbsolutePath)
+          Seq("a,a", "1,1").toDF().coalesce(1).write.mode("overwrite").text(testDir.getAbsolutePath)
+          val df = spark.read.format("csv").option("inferSchema", true).option("header", true)
+            .load(testDir.getAbsolutePath)
+          checkAnswer(df, Row(1, 1))
+          checkReadPartitionColumnDuplication("csv", c0, c1, src)
+
+          // Check JSON format
+          checkWriteDataColumnDuplication("json", c0, c1, src)
+          checkReadUserSpecifiedDataColumnDuplication(
+            Seq((1, 1)).toDF("c0", "c1"), "json", c0, c1, src)
+          // Inferred schema cases
+          testDir = Utils.createTempDir(src.getAbsolutePath)
+          Seq(s"""{"$c0":3, "$c1":5}""").toDF().write.mode("overwrite")
+            .text(testDir.getAbsolutePath)
+          val errorMsg = intercept[AnalysisException] {
+            spark.read.format("json").option("inferSchema", true).load(testDir.getAbsolutePath)
+          }.getMessage
+          assert(errorMsg.contains("Found duplicate column(s) in the data schema:"))
+          checkReadPartitionColumnDuplication("json", c0, c1, src)
+
+          // Check Parquet format
+          checkWriteDataColumnDuplication("parquet", c0, c1, src)
+          checkReadUserSpecifiedDataColumnDuplication(
+            Seq((1, 1)).toDF("c0", "c1"), "parquet", c0, c1, src)
+          checkReadPartitionColumnDuplication("parquet", c0, c1, src)
+        }
+      }
+    }
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index 31fa3d244746..12daf3af11ab 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -345,7 +345,7 @@ class HiveDDLSuite
     val e = intercept[AnalysisException] {
       sql("CREATE TABLE tbl(a int) PARTITIONED BY (a string)")
     }
-    assert(e.message == "Found duplicate column(s) in table definition of `default`.`tbl`: a")
+    assert(e.message == "Found duplicate column(s) in the table definition of `default`.`tbl`: `a`")
   }

   test("add/drop partition with location - managed table") {
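
Seen end to end, the user-facing effect of the patch is easiest to show with a spark-shell style session; an illustrative sketch that assumes spark.implicits._ is in scope and uses a made-up output path:

    import org.apache.spark.sql.AnalysisException

    val path = "/tmp/spark-20460-demo"  // illustrative path only

    // Writing a schema that repeats a column name (case-insensitively, by default) fails up front.
    try {
      Seq((1, 2)).toDF("c0", "C0").write.mode("overwrite").parquet(path)
    } catch {
      case e: AnalysisException =>
        println(e.getMessage)  // Found duplicate column(s) when inserting into file:/tmp/...
    }

    // A user-specified read schema with duplicate names is rejected as well.
    try {
      spark.read.schema("c0 INT, C0 INT").parquet(path)
    } catch {
      case e: AnalysisException =>
        println(e.getMessage)  // Found duplicate column(s) in the data schema: `c0`
    }

    // Overlap between the data schema and the partition schema is the one case that is only
    // logged as a warning (see the DataSource.scala change above), to stay compatible with
    // earlier releases.
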