Closed
Changes from all commits (27 commits)
1e647ee
Share code to check column name duplication
maropu Apr 25, 2017
4467077
Apply reviews
maropu Jun 13, 2017
33ab217
Make code more consistent
maropu Jun 13, 2017
d8efb9d
Apply review comments
maropu Jun 16, 2017
11d1818
Apply xiao's reviews
maropu Jun 16, 2017
22e1e4f
Apply more xiao's reviews
maropu Jun 17, 2017
743a069
Replace map with foreach
maropu Jun 20, 2017
f6eab2d
Add tests for data schema + partition schema
maropu Jun 20, 2017
09da8d6
Drop name duplication checks in HiveMetastoreCatalog.scala
maropu Jun 20, 2017
6d03f31
Modify exception messages
maropu Jun 20, 2017
a0b9b05
Revert logic to check name duplication
maropu Jun 20, 2017
91b6424
Add tests for write paths
maropu Jun 21, 2017
37ad3f3
Add tests for stream sink paths
maropu Jun 21, 2017
d0d9d3e
Brushes up code and adds more tests
maropu Jun 25, 2017
cbe9c71
Apply reviews
maropu Jun 26, 2017
c69270f
Apply more comments
maropu Jun 27, 2017
af959f6
Add more tests in create.sql
maropu Jun 27, 2017
8d3e10a
Move duplication checks in constructor
maropu Jun 29, 2017
9b386d5
Brush up code
maropu Jun 30, 2017
a878510
[WIP] Add DataSourceValidator trait to validate schema in write path
maropu Jul 3, 2017
be20127
Revert "Brush up code"
maropu Jul 3, 2017
f41bf80
Fix more issues
maropu Jul 4, 2017
0526391
Revert DataSourceValidator
maropu Jul 4, 2017
9e199bc
Add the check for external relation providers
maropu Jul 4, 2017
1ae132d
[WIP] Handle DataSource name duplication in one place
maropu Jul 5, 2017
5c29a75
Fix more
maropu Jul 6, 2017
5ed2c0d
Move some tests to DDLSuite
maropu Jul 7, 2017
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View}
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types.StructType

object SessionCatalog {
val DEFAULT_DATABASE = "default"
@@ -188,19 +188,6 @@ class SessionCatalog(
}
}

private def checkDuplication(fields: Seq[StructField]): Unit = {
val columnNames = if (conf.caseSensitiveAnalysis) {
fields.map(_.name)
} else {
fields.map(_.name.toLowerCase)
}
if (columnNames.distinct.length != columnNames.length) {
val duplicateColumns = columnNames.groupBy(identity).collect {
case (x, ys) if ys.length > 1 => x
}
throw new AnalysisException(s"Found duplicate column(s): ${duplicateColumns.mkString(", ")}")
}
}
// ----------------------------------------------------------------------------
// Databases
// ----------------------------------------------------------------------------
@@ -353,7 +340,6 @@ class SessionCatalog(
val tableIdentifier = TableIdentifier(table, Some(db))
requireDbExists(db)
requireTableExists(tableIdentifier)
checkDuplication(newSchema)

val catalogTable = externalCatalog.getTable(db, table)
val oldSchema = catalogTable.schema
@@ -17,37 +17,73 @@

package org.apache.spark.sql.util

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types.StructType


/**
* Utils for handling schemas.
*
* TODO: Merge this file with [[org.apache.spark.ml.util.SchemaUtils]].
*/
private[spark] object SchemaUtils extends Logging {
private[spark] object SchemaUtils {

/**
* Checks if input column names have duplicate identifiers. Prints a warning message if
* Checks if an input schema has duplicate column names. This throws an exception if the
* duplication exists.
*
* @param schema schema to check
* @param colType column type name, used in an exception message
* @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not
*/
def checkSchemaColumnNameDuplication(
schema: StructType, colType: String, caseSensitiveAnalysis: Boolean = false): Unit = {
checkColumnNameDuplication(schema.map(_.name), colType, caseSensitiveAnalysis)
}

// Returns true if a given resolver is case-sensitive
private def isCaseSensitiveAnalysis(resolver: Resolver): Boolean = {
if (resolver == caseSensitiveResolution) {
true
} else if (resolver == caseInsensitiveResolution) {
false
} else {
sys.error("A resolver to check if two identifiers are equal must be " +
"`caseSensitiveResolution` or `caseInsensitiveResolution` in o.a.s.sql.catalyst.")
}
}

/**
* Checks if input column names have duplicate identifiers. This throws an exception if
* the duplication exists.
*
* @param columnNames column names to check
* @param colType column type name, used in a warning message
* @param colType column type name, used in an exception message
* @param resolver resolver used to determine if two identifiers are equal
*/
def checkColumnNameDuplication(
columnNames: Seq[String], colType: String, resolver: Resolver): Unit = {
checkColumnNameDuplication(columnNames, colType, isCaseSensitiveAnalysis(resolver))
}

/**
* Checks if input column names have duplicate identifiers. This throws an exception if
* the duplication exists.
*
* @param columnNames column names to check
* @param colType column type name, used in an exception message
* @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not
*/
def checkColumnNameDuplication(
columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit = {
val names = if (caseSensitiveAnalysis) {
columnNames
} else {
columnNames.map(_.toLowerCase)
}
val names = if (caseSensitiveAnalysis) columnNames else columnNames.map(_.toLowerCase)
if (names.distinct.length != names.length) {
val duplicateColumns = names.groupBy(identity).collect {
case (x, ys) if ys.length > 1 => s"`$x`"
}
logWarning(s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}. " +
"You might need to assign different column names.")
throw new AnalysisException(
s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}")
}
}
}
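
Note: for reference, a minimal sketch of how the new utility behaves, mirroring the SchemaUtilsSuite cases below (the colType string "in the table definition" is just an illustrative label, and SchemaUtils is private[spark], so this only applies to code inside Spark itself):

import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.SchemaUtils

val schema = StructType.fromDDL("a INT, b INT, A INT")

// Case-insensitive analysis: `a` and `A` collide, so this throws
// org.apache.spark.sql.AnalysisException: Found duplicate column(s) in the table definition: `a`
SchemaUtils.checkSchemaColumnNameDuplication(
  schema, "in the table definition", caseSensitiveAnalysis = false)

// Case-sensitive analysis: `a` and `A` are distinct identifiers, so the check passes.
SchemaUtils.checkSchemaColumnNameDuplication(
  schema, "in the table definition", caseSensitiveAnalysis = true)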
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types.StructType

class SchemaUtilsSuite extends SparkFunSuite {

private def resolver(caseSensitiveAnalysis: Boolean): Resolver = {
if (caseSensitiveAnalysis) {
caseSensitiveResolution
} else {
caseInsensitiveResolution
}
}

Seq((true, ("a", "a"), ("b", "b")), (false, ("a", "A"), ("b", "B"))).foreach {
case (caseSensitive, (a0, a1), (b0, b1)) =>

val testType = if (caseSensitive) "case-sensitive" else "case-insensitive"
test(s"Check column name duplication in $testType cases") {
def checkExceptionCases(schemaStr: String, duplicatedColumns: Seq[String]): Unit = {
val expectedErrorMsg = "Found duplicate column(s) in SchemaUtilsSuite: " +
duplicatedColumns.map(c => s"`${c.toLowerCase}`").mkString(", ")
val schema = StructType.fromDDL(schemaStr)
var msg = intercept[AnalysisException] {
SchemaUtils.checkSchemaColumnNameDuplication(
schema, "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive)
}.getMessage
assert(msg.contains(expectedErrorMsg))
msg = intercept[AnalysisException] {
SchemaUtils.checkColumnNameDuplication(
schema.map(_.name), "in SchemaUtilsSuite", resolver(caseSensitive))
}.getMessage
assert(msg.contains(expectedErrorMsg))
msg = intercept[AnalysisException] {
SchemaUtils.checkColumnNameDuplication(
schema.map(_.name), "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive)
}.getMessage
assert(msg.contains(expectedErrorMsg))
}

checkExceptionCases(s"$a0 INT, b INT, $a1 INT", a0 :: Nil)
checkExceptionCases(s"$a0 INT, b INT, $a1 INT, $a0 INT", a0 :: Nil)
checkExceptionCases(s"$a0 INT, $b0 INT, $a1 INT, $a0 INT, $b1 INT", b0 :: a0 :: Nil)
}
}

test("Check no exception thrown for valid schemas") {
def checkNoExceptionCases(schemaStr: String, caseSensitive: Boolean): Unit = {
val schema = StructType.fromDDL(schemaStr)
SchemaUtils.checkSchemaColumnNameDuplication(
schema, "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive)
SchemaUtils.checkColumnNameDuplication(
schema.map(_.name), "in SchemaUtilsSuite", resolver(caseSensitive))
SchemaUtils.checkColumnNameDuplication(
schema.map(_.name), "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive)
}

checkNoExceptionCases("a INT, b INT, c INT", caseSensitive = true)
checkNoExceptionCases("Aa INT, b INT, aA INT", caseSensitive = true)

checkNoExceptionCases("a INT, b INT, c INT", caseSensitive = false)
}
}
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.command

import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.command
import java.io.File
import java.net.URI
import java.nio.file.FileSystems
import java.util.Date

import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import scala.util.Try

import org.apache.commons.lang3.StringEscapeUtils
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
@@ -42,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.util.Utils

/**
@@ -202,6 +201,11 @@ case class AlterTableAddColumnsCommand(

// make sure any partition columns are at the end of the fields
val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema

SchemaUtils.checkColumnNameDuplication(
reorderedSchema.map(_.name), "in the table definition of " + table.identifier,
conf.caseSensitiveAnalysis)

catalog.alterTableSchema(
table, catalogTable.schema.copy(fields = reorderedSchema.toArray))
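
Note: a hypothetical end-to-end illustration of the new check (the table name and exact message text are assumptions, not taken from the patch); adding a column that clashes with an existing one now fails here instead of producing a corrupt table schema:

scala> sql("CREATE TABLE t(a INT, b INT) USING parquet")
scala> sql("ALTER TABLE t ADD COLUMNS (a STRING)")
org.apache.spark.sql.AnalysisException: Found duplicate column(s) in the table definition of `default`.`t`: `a`;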

@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View}
import org.apache.spark.sql.types.MetadataBuilder
import org.apache.spark.sql.util.SchemaUtils


/**
@@ -355,15 +356,15 @@ object ViewHelper {
properties: Map[String, String],
session: SparkSession,
analyzedPlan: LogicalPlan): Map[String, String] = {
val queryOutput = analyzedPlan.schema.fieldNames

// Generate the query column names, throw an AnalysisException if there exists duplicate column
// names.
val queryOutput = analyzedPlan.schema.fieldNames
assert(queryOutput.distinct.size == queryOutput.size,
s"The view output ${queryOutput.mkString("(", ",", ")")} contains duplicate column name.")
SchemaUtils.checkColumnNameDuplication(
queryOutput, "in the view definition", session.sessionState.conf.resolver)

// Generate the view default database name.
val viewDefaultDatabase = session.sessionState.catalog.getCurrentDatabase

removeQueryColumnNames(properties) ++
generateViewDefaultDatabase(viewDefaultDatabase) ++
generateQueryColumnNames(queryOutput)
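
Note: a sketch of the user-visible effect, assuming the default case-insensitive resolver (the view name and exact message are hypothetical); the duplicate is now reported as an AnalysisException instead of tripping an internal assertion:

scala> sql("CREATE VIEW v AS SELECT 1 AS a, 2 AS A")
org.apache.spark.sql.AnalysisException: Found duplicate column(s) in the view definition: `a`;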
@@ -87,6 +87,14 @@ case class DataSource(
lazy val providingClass: Class[_] = DataSource.lookupDataSource(className)
lazy val sourceInfo: SourceInfo = sourceSchema()
private val caseInsensitiveOptions = CaseInsensitiveMap(options)
private val equality = sparkSession.sessionState.conf.resolver

bucketSpec.map { bucket =>
SchemaUtils.checkColumnNameDuplication(
bucket.bucketColumnNames, "in the bucket definition", equality)
SchemaUtils.checkColumnNameDuplication(
bucket.sortColumnNames, "in the sort definition", equality)
}
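
Note: a sketch of what this constructor check guards against (the table name is hypothetical, and the exact point where the error surfaces depends on which validation runs first):

scala> Seq((1, 2)).toDF("a", "b").write.bucketBy(8, "a", "a").saveAsTable("t")
org.apache.spark.sql.AnalysisException: Found duplicate column(s) in the bucket definition: `a`;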

/**
* Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer
@@ -132,7 +140,6 @@
// Try to infer partitioning, because no DataSource in the read path provides the partitioning
// columns properly unless it is a Hive DataSource
val resolved = tempFileIndex.partitionSchema.map { partitionField =>
val equality = sparkSession.sessionState.conf.resolver
// SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
partitionField)
@@ -146,7 +153,6 @@
inferredPartitions
} else {
val partitionFields = partitionColumns.map { partitionColumn =>
val equality = sparkSession.sessionState.conf.resolver
userSpecifiedSchema.flatMap(_.find(c => equality(c.name, partitionColumn))).orElse {
val inferredPartitions = tempFileIndex.partitionSchema
val inferredOpt = inferredPartitions.find(p => equality(p.name, partitionColumn))
@@ -172,7 +178,6 @@ }
}

val dataSchema = userSpecifiedSchema.map { schema =>
val equality = sparkSession.sessionState.conf.resolver
StructType(schema.filterNot(f => partitionSchema.exists(p => equality(p.name, f.name))))
}.orElse {
format.inferSchema(
@@ -184,9 +189,18 @@
s"Unable to infer schema for $format. It must be specified manually.")
}

SchemaUtils.checkColumnNameDuplication(
(dataSchema ++ partitionSchema).map(_.name), "in the data schema and the partition schema",
sparkSession.sessionState.conf.caseSensitiveAnalysis)
// We just print a warning message if the data schema and the partition schema have duplicate
// columns. This is because we allowed users to do so in previous Spark releases and
// there are existing tests for these cases (e.g., `ParquetHadoopFsRelationSuite`).
// See SPARK-18108 and SPARK-21144 for related discussions.
try {
SchemaUtils.checkColumnNameDuplication(
Contributor: shall we put this check in the constructor of DataSource? So it works for both the read and write paths.

Member Author (@maropu, Jun 27, 2017): Since dataSchema and partitionSchema are first fixed by this function (getOrInferFileFormatSchema), IIUC we couldn't easily put this check in the constructor. If we put the check there, we would need to move some of the schema-resolution code from this function into the constructor. Thoughts?

Contributor: this check also works for file-based data sources; how about other data sources?

Member Author (@maropu, Jun 27, 2017):
I rechecked the related code paths again, but I couldn't find this issue (dataSchema and partitionSchema having duplicate column names) in other data sources. Actually, I think this issue only happens in file-based data sources, when users directly write partition directories (e.g., Seq(1, 2, 3).toDF("a").write.parquet(s"$path/a=1")).

In catalog tables, dataSchema and partitionSchema can't have duplicate names in the write paths;


scala> sql("""CREATE TABLE t1(a INT, b INT, c INT) PARTITIONED BY (a INT)""")
org.apache.spark.sql.AnalysisException: Found duplicate column(s) in the table definition of `default`.`t1`: `a`;
  at org.apache.spark.sql.util.SchemaUtils$.checkColumnNameDuplication(SchemaUtils.scala:85)
  at org.apache.spark.sql.util.SchemaUtils$.checkSchemaColumnNameDuplication(SchemaUtils.scala:42)
  at org.apache.spark.sql.execution.datasources.PreprocessTableCreation.org$apache$spark$sql$execution$datasources$PreprocessTableCreation$$normalizeCatalogTable(rules.scala:226)

scala> sql("""CREATE TABLE t2(a INT, b INT, c INT)""")
scala> sql("""ALTER TABLE t2 ADD PARTITION (a = 1)""")
org.apache.spark.sql.AnalysisException: a is not a valid partition column in table `default`.`t2`.;
  at org.apache.spark.sql.execution.datasources.PartitioningUtils$$anonfun$7$$anonfun$9.apply(PartitioningUtils.scala:300)
  at org.apache.spark.sql.execution.datasources.PartitioningUtils$$anonfun$7$$anonfun$9.apply(PartitioningUtils.scala:300)
  at scala.Option.getOrElse(Option.scala:121)

In stream sources, schemas have no partitions (so this issue does not happen) in the read paths;
https://github.com/maropu/spark/blob/ad30aded7e95bb51d2028a4a21998c72c0338b3a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala#L210
https://github.com/maropu/spark/blob/ad30aded7e95bb51d2028a4a21998c72c0338b3a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala#L255

In stream sinks, since we assume partition columns are selected from the data columns (https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala#L116), this issue does not happen.

So, IMHO the duplication check in getOrInferFileFormatSchema is enough for this case.

(dataSchema ++ partitionSchema).map(_.name),
"in the data schema and the partition schema",
equality)
} catch {
case e: AnalysisException => logWarning(e.getMessage)
}

(dataSchema, partitionSchema)
}
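
Note: to make the scenario from the review thread concrete, a minimal sketch (the path is hypothetical): the partition directory name reintroduces a column that already exists in the file data, and with this change the overlap is only logged as a warning, so the read keeps working as it did in earlier releases:

scala> val path = "/tmp/spark-dup-col-example"
scala> Seq(1, 2, 3).toDF("a").write.parquet(s"$path/a=1")
scala> spark.read.parquet(path).printSchema()
// Logs a warning like "Found duplicate column(s) in the data schema and the partition schema: `a`",
// but the read itself is not rejected.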
@@ -391,6 +405,23 @@ case class DataSource(
s"$className is not a valid Spark SQL Data Source.")
}

relation match {
case hs: HadoopFsRelation =>
SchemaUtils.checkColumnNameDuplication(
hs.dataSchema.map(_.name),
"in the data schema",
equality)
SchemaUtils.checkColumnNameDuplication(
hs.partitionSchema.map(_.name),
"in the partition schema",
equality)
case _ =>
SchemaUtils.checkColumnNameDuplication(
relation.schema.map(_.name),
"in the data schema",
equality)
}

relation
}

@@ -21,7 +21,6 @@ import java.io.IOException

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.SparkContext
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition}
@@ -30,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.util.SchemaUtils

/**
* A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending.
@@ -64,13 +63,10 @@ case class InsertIntoHadoopFsRelationCommand(
assert(children.length == 1)

// Most formats don't do well with duplicate columns, so lets not allow that
if (query.schema.fieldNames.length != query.schema.fieldNames.distinct.length) {
val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect {
case (x, ys) if ys.length > 1 => "\"" + x + "\""
}.mkString(", ")
throw new AnalysisException(s"Duplicate column(s): $duplicateColumns found, " +
"cannot save to file.")
}
SchemaUtils.checkSchemaColumnNameDuplication(
Contributor: do you know when we will hit this? I think it's already done before we reach here.

Member Author (@maropu): I'll re-check the code path. Just a sec.

query.schema,
s"when inserting into $outputPath",
sparkSession.sessionState.conf.caseSensitiveAnalysis)
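
Note: as the review thread above points out, duplicate output columns are usually rejected earlier in the write path, so this check is largely defensive; either way, the observable behavior is an AnalysisException along these lines (the path and exact wording are assumptions):

scala> Seq((1, "x")).toDF("a", "b").select($"a", $"a").write.parquet("/tmp/dup-out")
org.apache.spark.sql.AnalysisException: Found duplicate column(s) when inserting into file:/tmp/dup-out: `a`;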

val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options)
val fs = outputPath.getFileSystem(hadoopConf)