From 8fbd18d66201b2512fc4d728941ce2fc6b2400c3 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Mon, 14 Nov 2016 00:59:23 -0800
Subject: [PATCH 1/2] [SPARK-18433][SQL] Improve DataSource option keys to be more case-insensitive

---
 .../spark/sql/catalyst/json/JSONOptions.scala | 37 +++++++++++--------
 .../catalyst/util/CaseInsensitiveMap.scala | 36 ++++++++++++++++++
 .../execution/datasources/DataSource.scala | 30 ++++++++-------
 .../datasources/csv/CSVOptions.scala | 34 +++++++++--------
 .../spark/sql/execution/datasources/ddl.scala | 18 ---------
 .../datasources/jdbc/JDBCOptions.scala | 35 ++++++++++--------
 .../datasources/parquet/ParquetOptions.scala | 10 +++--
 .../streaming/FileStreamOptions.scala | 19 ++++++----
 .../datasources/csv/CSVInferSchemaSuite.scala | 5 +++
 .../datasources/json/JsonSuite.scala | 14 +++++++
 .../spark/sql/jdbc/JDBCWriteSuite.scala | 9 +++++
 .../sql/streaming/FileStreamSourceSuite.scala | 5 +++
 .../spark/sql/hive/HiveExternalCatalog.scala | 2 +-
 .../spark/sql/hive/orc/OrcOptions.scala | 8 +++-
 .../spark/sql/hive/orc/OrcSourceSuite.scala | 4 ++
 .../apache/spark/sql/hive/parquetSuites.scala | 8 ++++
 16 files changed, 183 insertions(+), 91 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index c45970658cf0..cd62c85a1788
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -23,7 +23,7 @@ import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
 import org.apache.commons.lang3.time.FastDateFormat
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.util.{CompressionCodecs, ParseModes}
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}
 
 /**
  * Options for parsing JSON data into Spark SQL rows.
@@ -34,35 +34,42 @@ private[sql] class JSONOptions( @transient private val parameters: Map[String, String]) extends Logging with Serializable { + private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + val samplingRatio = - parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + caseInsensitiveOptions.get("samplingRatio").map(_.toDouble).getOrElse(1.0) val primitivesAsString = - parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) + caseInsensitiveOptions.get("primitivesAsString").map(_.toBoolean).getOrElse(false) val prefersDecimal = - parameters.get("prefersDecimal").map(_.toBoolean).getOrElse(false) + caseInsensitiveOptions.get("prefersDecimal").map(_.toBoolean).getOrElse(false) val allowComments = - parameters.get("allowComments").map(_.toBoolean).getOrElse(false) + caseInsensitiveOptions.get("allowComments").map(_.toBoolean).getOrElse(false) val allowUnquotedFieldNames = - parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false) + caseInsensitiveOptions.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false) val allowSingleQuotes = - parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true) + caseInsensitiveOptions.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true) val allowNumericLeadingZeros = - parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false) + caseInsensitiveOptions.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false) val allowNonNumericNumbers = - parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + caseInsensitiveOptions.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) val allowBackslashEscapingAnyCharacter = - parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) - val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) - private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") - val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord") + caseInsensitiveOptions.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean) + .getOrElse(false) + val compressionCodec = + caseInsensitiveOptions.get("compression").map(CompressionCodecs.getCodecClassName) + private val parseMode = caseInsensitiveOptions.getOrElse("mode", "PERMISSIVE") + val columnNameOfCorruptRecord = caseInsensitiveOptions.get("columnNameOfCorruptRecord") // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) + FastDateFormat.getInstance( + caseInsensitiveOptions.getOrElse("dateFormat", "yyyy-MM-dd"), + Locale.US) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US) + caseInsensitiveOptions.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), + Locale.US) // Parse mode flags if (!ParseModes.isValidMode(parseMode)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala new file mode 100644 index 000000000000..a7f7a8a66382 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +/** + * Builds a map in which keys are case insensitive + */ +class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] + with Serializable { + + val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) + + override def get(k: String): Option[String] = baseMap.get(k.toLowerCase) + + override def + [B1 >: String](kv: (String, B1)): Map[String, B1] = + baseMap + kv.copy(_1 = kv._1.toLowerCase) + + override def iterator: Iterator[(String, String)] = baseMap.iterator + + override def -(key: String): Map[String, String] = baseMap - key.toLowerCase +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 65422f1495f0..cfee7be1e3f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -80,13 +81,13 @@ case class DataSource( lazy val providingClass: Class[_] = DataSource.lookupDataSource(className) lazy val sourceInfo = sourceSchema() + private val caseInsensitiveOptions = new CaseInsensitiveMap(options) /** * Infer the schema of the given FileFormat, returns a pair of schema and partition column names. 
*/ private def inferFileFormatSchema(format: FileFormat): (StructType, Seq[String]) = { userSpecifiedSchema.map(_ -> partitionColumns).orElse { - val caseInsensitiveOptions = new CaseInsensitiveMap(options) val allPaths = caseInsensitiveOptions.get("path") val globbedPaths = allPaths.toSeq.flatMap { path => val hdfsPath = new Path(path) @@ -114,11 +115,10 @@ case class DataSource( providingClass.newInstance() match { case s: StreamSourceProvider => val (name, schema) = s.sourceSchema( - sparkSession.sqlContext, userSpecifiedSchema, className, options) + sparkSession.sqlContext, userSpecifiedSchema, className, caseInsensitiveOptions) SourceInfo(name, schema, Nil) case format: FileFormat => - val caseInsensitiveOptions = new CaseInsensitiveMap(options) val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") }) @@ -158,10 +158,14 @@ case class DataSource( providingClass.newInstance() match { case s: StreamSourceProvider => s.createSource( - sparkSession.sqlContext, metadataPath, userSpecifiedSchema, className, options) + sparkSession.sqlContext, + metadataPath, + userSpecifiedSchema, + className, + caseInsensitiveOptions) case format: FileFormat => - val path = new CaseInsensitiveMap(options).getOrElse("path", { + val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") }) new FileStreamSource( @@ -171,7 +175,7 @@ case class DataSource( schema = sourceInfo.schema, partitionColumns = sourceInfo.partitionColumns, metadataPath = metadataPath, - options = options) + options = caseInsensitiveOptions) case _ => throw new UnsupportedOperationException( s"Data source $className does not support streamed reading") @@ -182,10 +186,9 @@ case class DataSource( def createSink(outputMode: OutputMode): Sink = { providingClass.newInstance() match { case s: StreamSinkProvider => - s.createSink(sparkSession.sqlContext, options, partitionColumns, outputMode) + s.createSink(sparkSession.sqlContext, caseInsensitiveOptions, partitionColumns, outputMode) case fileFormat: FileFormat => - val caseInsensitiveOptions = new CaseInsensitiveMap(options) val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") }) @@ -193,7 +196,7 @@ case class DataSource( throw new IllegalArgumentException( s"Data source $className does not support $outputMode output mode") } - new FileStreamSink(sparkSession, path, fileFormat, partitionColumns, options) + new FileStreamSink(sparkSession, path, fileFormat, partitionColumns, caseInsensitiveOptions) case _ => throw new UnsupportedOperationException( @@ -234,7 +237,6 @@ case class DataSource( * that files already exist, we don't need to check them again. */ def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = { - val caseInsensitiveOptions = new CaseInsensitiveMap(options) val relation = (providingClass.newInstance(), userSpecifiedSchema) match { // TODO: Throw when too much is given. case (dataSource: SchemaRelationProvider, Some(schema)) => @@ -274,7 +276,7 @@ case class DataSource( dataSchema = dataSchema, bucketSpec = None, format, - options)(sparkSession) + caseInsensitiveOptions)(sparkSession) // This is a non-streaming file based datasource. 
case (format: FileFormat, _) => @@ -358,13 +360,13 @@ case class DataSource( providingClass.newInstance() match { case dataSource: CreatableRelationProvider => - dataSource.createRelation(sparkSession.sqlContext, mode, options, data) + dataSource.createRelation(sparkSession.sqlContext, mode, caseInsensitiveOptions, data) case format: FileFormat => // Don't glob path for the write path. The contracts here are: // 1. Only one output path can be specified on the write path; // 2. Output path must be a legal HDFS style file system path; // 3. It's OK that the output path doesn't exist yet; - val allPaths = paths ++ new CaseInsensitiveMap(options).get("path") + val allPaths = paths ++ caseInsensitiveOptions.get("path") val outputPath = if (allPaths.length == 1) { val path = new Path(allPaths.head) val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) @@ -391,7 +393,7 @@ case class DataSource( // TODO: Case sensitivity. val sameColumns = existingPartitionColumns.map(_.toLowerCase()) == partitionColumns.map(_.toLowerCase()) - if (existingPartitionColumns.size > 0 && !sameColumns) { + if (existingPartitionColumns.nonEmpty && !sameColumns) { throw new AnalysisException( s"""Requested partitioning does not match existing partitioning. |Existing partitioning columns: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 5903729c11fc..9831949a8329 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -23,13 +23,15 @@ import java.util.Locale import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.util.{CompressionCodecs, ParseModes} +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} private[csv] class CSVOptions(@transient private val parameters: Map[String, String]) extends Logging with Serializable { + private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + private def getChar(paramName: String, default: Char): Char = { - val paramValue = parameters.get(paramName) + val paramValue = caseInsensitiveOptions.get(paramName) paramValue match { case None => default case Some(null) => default @@ -40,7 +42,7 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str } private def getInt(paramName: String, default: Int): Int = { - val paramValue = parameters.get(paramName) + val paramValue = caseInsensitiveOptions.get(paramName) paramValue match { case None => default case Some(null) => default @@ -54,7 +56,7 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str } private def getBool(paramName: String, default: Boolean = false): Boolean = { - val param = parameters.getOrElse(paramName, default.toString) + val param = caseInsensitiveOptions.getOrElse(paramName, default.toString) if (param == null) { default } else if (param.toLowerCase == "true") { @@ -67,10 +69,10 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str } val delimiter = CSVTypeCast.toChar( - parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) - private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") - val charset = parameters.getOrElse("encoding", - parameters.getOrElse("charset", 
StandardCharsets.UTF_8.name())) + caseInsensitiveOptions.getOrElse("sep", caseInsensitiveOptions.getOrElse("delimiter", ","))) + private val parseMode = caseInsensitiveOptions.getOrElse("mode", "PERMISSIVE") + val charset = caseInsensitiveOptions.getOrElse("encoding", + caseInsensitiveOptions.getOrElse("charset", StandardCharsets.UTF_8.name())) val quote = getChar("quote", '\"') val escape = getChar("escape", '\\') @@ -90,26 +92,28 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str val dropMalformed = ParseModes.isDropMalformedMode(parseMode) val permissive = ParseModes.isPermissiveMode(parseMode) - val nullValue = parameters.getOrElse("nullValue", "") + val nullValue = caseInsensitiveOptions.getOrElse("nullValue", "") - val nanValue = parameters.getOrElse("nanValue", "NaN") + val nanValue = caseInsensitiveOptions.getOrElse("nanValue", "NaN") - val positiveInf = parameters.getOrElse("positiveInf", "Inf") - val negativeInf = parameters.getOrElse("negativeInf", "-Inf") + val positiveInf = caseInsensitiveOptions.getOrElse("positiveInf", "Inf") + val negativeInf = caseInsensitiveOptions.getOrElse("negativeInf", "-Inf") val compressionCodec: Option[String] = { - val name = parameters.get("compression").orElse(parameters.get("codec")) + val name = caseInsensitiveOptions.get("compression").orElse(caseInsensitiveOptions.get("codec")) name.map(CompressionCodecs.getCodecClassName) } // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) + FastDateFormat.getInstance( + caseInsensitiveOptions.getOrElse("dateFormat", "yyyy-MM-dd"), + Locale.US) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US) + caseInsensitiveOptions.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US) val maxColumns = getInt("maxColumns", 20480) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 59fb48ffea59..fa8dfa9640d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -96,21 +96,3 @@ case class RefreshResource(path: String) Seq.empty[Row] } } - -/** - * Builds a map in which keys are case insensitive - */ -class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] - with Serializable { - - val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) - - override def get(k: String): Option[String] = baseMap.get(k.toLowerCase) - - override def + [B1 >: String](kv: (String, B1)): Map[String, B1] = - baseMap + kv.copy(_1 = kv._1.toLowerCase) - - override def iterator: Iterator[(String, String)] = baseMap.iterator - - override def -(key: String): Map[String, String] = baseMap - key.toLowerCase -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index fcd7409159de..69466fa034aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -22,6 +22,8 @@ import 
java.util.Properties import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap + /** * Options for the JDBC data source. */ @@ -31,6 +33,8 @@ class JDBCOptions( import JDBCOptions._ + private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + def this(url: String, table: String, parameters: Map[String, String]) = { this(parameters ++ Map( JDBCOptions.JDBC_URL -> url, @@ -40,7 +44,7 @@ class JDBCOptions( val asConnectionProperties: Properties = { val properties = new Properties() // We should avoid to pass the options into properties. See SPARK-17776. - parameters.filterKeys(!jdbcOptionNames.contains(_)) + caseInsensitiveOptions.filterKeys(!jdbcOptionNames.contains(_)) .foreach { case (k, v) => properties.setProperty(k, v) } properties } @@ -48,18 +52,19 @@ class JDBCOptions( // ------------------------------------------------------------ // Required parameters // ------------------------------------------------------------ - require(parameters.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.") - require(parameters.isDefinedAt(JDBC_TABLE_NAME), s"Option '$JDBC_TABLE_NAME' is required.") + require(caseInsensitiveOptions.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.") + require(caseInsensitiveOptions.isDefinedAt(JDBC_TABLE_NAME), + s"Option '$JDBC_TABLE_NAME' is required.") // a JDBC URL - val url = parameters(JDBC_URL) + val url = caseInsensitiveOptions(JDBC_URL) // name of table - val table = parameters(JDBC_TABLE_NAME) + val table = caseInsensitiveOptions(JDBC_TABLE_NAME) // ------------------------------------------------------------ // Optional parameters // ------------------------------------------------------------ val driverClass = { - val userSpecifiedDriverClass = parameters.get(JDBC_DRIVER_CLASS) + val userSpecifiedDriverClass = caseInsensitiveOptions.get(JDBC_DRIVER_CLASS) userSpecifiedDriverClass.foreach(DriverRegistry.register) // Performing this part of the logic on the driver guards against the corner-case where the @@ -74,19 +79,19 @@ class JDBCOptions( // Optional parameters only for reading // ------------------------------------------------------------ // the column used to partition - val partitionColumn = parameters.getOrElse(JDBC_PARTITION_COLUMN, null) + val partitionColumn = caseInsensitiveOptions.getOrElse(JDBC_PARTITION_COLUMN, null) // the lower bound of partition column - val lowerBound = parameters.getOrElse(JDBC_LOWER_BOUND, null) + val lowerBound = caseInsensitiveOptions.getOrElse(JDBC_LOWER_BOUND, null) // the upper bound of the partition column - val upperBound = parameters.getOrElse(JDBC_UPPER_BOUND, null) + val upperBound = caseInsensitiveOptions.getOrElse(JDBC_UPPER_BOUND, null) // the number of partitions - val numPartitions = parameters.getOrElse(JDBC_NUM_PARTITIONS, null) + val numPartitions = caseInsensitiveOptions.getOrElse(JDBC_NUM_PARTITIONS, null) require(partitionColumn == null || (lowerBound != null && upperBound != null && numPartitions != null), s"If '$JDBC_PARTITION_COLUMN' is specified then '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND'," + s" and '$JDBC_NUM_PARTITIONS' are required.") val fetchSize = { - val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt + val size = caseInsensitiveOptions.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt require(size >= 0, s"Invalid value `${size.toString}` for parameter " + s"`$JDBC_BATCH_FETCH_SIZE`. The minimum value is 0. 
When the value is 0, " + @@ -98,20 +103,20 @@ class JDBCOptions( // Optional parameters only for writing // ------------------------------------------------------------ // if to truncate the table from the JDBC database - val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean + val isTruncate = caseInsensitiveOptions.getOrElse(JDBC_TRUNCATE, "false").toBoolean // the create table option , which can be table_options or partition_options. // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options - val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "") + val createTableOptions = caseInsensitiveOptions.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "") val batchSize = { - val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt + val size = caseInsensitiveOptions.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt require(size >= 1, s"Invalid value `${size.toString}` for parameter " + s"`$JDBC_BATCH_INSERT_SIZE`. The minimum value is 1.") size } val isolationLevel = - parameters.getOrElse(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match { + caseInsensitiveOptions.getOrElse(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match { case "NONE" => Connection.TRANSACTION_NONE case "READ_UNCOMMITTED" => Connection.TRANSACTION_READ_UNCOMMITTED case "READ_COMMITTED" => Connection.TRANSACTION_READ_COMMITTED diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index d0fd23605bea..aa9b729b2bd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -19,24 +19,28 @@ package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf /** * Options for the Parquet data source. */ -private[parquet] class ParquetOptions( +private[sql] class ParquetOptions( @transient private val parameters: Map[String, String], @transient private val sqlConf: SQLConf) extends Serializable { import ParquetOptions._ + private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + /** * Compression codec to use. By default use the value specified in SQLConf. * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase + val codecName = + caseInsensitiveOptions.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase) throw new IllegalArgumentException(s"Codec [$codecName] " + @@ -49,7 +53,7 @@ private[parquet] class ParquetOptions( * Whether it merges schemas or not. When the given Parquet files have different schemas, * the schemas can be merged. By default use the value specified in SQLConf. 
*/ - val mergeSchema: Boolean = parameters + val mergeSchema: Boolean = caseInsensitiveOptions .get(MERGE_SCHEMA) .map(_.toBoolean) .getOrElse(sqlConf.isParquetSchemaMergingEnabled) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala index 3efc20c1d662..644114fbddf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming import scala.util.Try import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.util.Utils /** @@ -28,11 +28,14 @@ import org.apache.spark.util.Utils */ class FileStreamOptions(parameters: Map[String, String]) extends Logging { - val maxFilesPerTrigger: Option[Int] = parameters.get("maxFilesPerTrigger").map { str => - Try(str.toInt).toOption.filter(_ > 0).getOrElse { - throw new IllegalArgumentException( - s"Invalid value '$str' for option 'maxFilesPerTrigger', must be a positive integer") - } + private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + + val maxFilesPerTrigger: Option[Int] = caseInsensitiveOptions.get("maxFilesPerTrigger").map { + str => + Try(str.toInt).toOption.filter(_ > 0).getOrElse { + throw new IllegalArgumentException( + s"Invalid value '$str' for option 'maxFilesPerTrigger', must be a positive integer") + } } /** @@ -46,9 +49,9 @@ class FileStreamOptions(parameters: Map[String, String]) extends Logging { * Default to a week. */ val maxFileAgeMs: Long = - Utils.timeStringAsMs(parameters.getOrElse("maxFileAge", "7d")) + Utils.timeStringAsMs(caseInsensitiveOptions.getOrElse("maxFileAge", "7d")) /** Options as specified by the user, in a case-insensitive map, without "path" set. 
*/ val optionMapWithoutPath: Map[String, String] = - new CaseInsensitiveMap(parameters).filterKeys(_ != "path") + new CaseInsensitiveMap(caseInsensitiveOptions).filterKeys(_ != "path") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 5e00f669b859..93f752d107ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -109,4 +109,9 @@ class CSVInferSchemaSuite extends SparkFunSuite { val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) assert(mergedNullTypes.deep == Array(NullType).deep) } + + test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { + val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm")) + assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 456052f79afc..27438b4a4588 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1749,4 +1749,18 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer(stringTimestampsWithFormat, expectedStringDatesWithFormat) } } + + test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { + val records = sparkContext + .parallelize("""{"a": 3, "b": 1.1}""" :: """{"a": 3.1, "b": 0.000001}""" :: Nil) + + val schema = StructType( + StructField("a", DecimalType(21, 1), true) :: + StructField("b", DecimalType(7, 6), true) :: Nil) + + val df1 = spark.read.option("prefersDecimal", "true").json(records) + assert(df1.schema == schema) + val df2 = spark.read.option("PREfersdecimaL", "true").json(records) + assert(df2.schema == schema) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 96540ec92da7..e3d3c6c3a887 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -303,4 +303,13 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(e.contains("If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," + " and 'numPartitions' are required.")) } + + test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + df.write.format("jdbc") + .option("Url", url1) + .option("dbtable", "TEST.SAVETEST") + .options(properties.asScala) + .save() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index fab7642994ff..b365af76c379 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -1004,6 +1004,11 @@ class FileStreamSourceSuite extends 
FileStreamSourceTest { ) } } + + test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { + val options = new FileStreamOptions(Map("maxfilespertrigger" -> "1")) + assert(options.maxFilesPerTrigger == Some(1)) + } } class FileStreamSourceStressTestSuite extends FileStreamSourceTest { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 42ce1a88a2b6..cbd00da81cfc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -35,8 +35,8 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} -import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index c2a126d3bf9c..2d1fff54699f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.orc +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap + /** * Options for the ORC data source. */ @@ -25,6 +27,8 @@ private[orc] class OrcOptions(@transient private val parameters: Map[String, Str import OrcOptions._ + private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + /** * Compression codec to use. By default snappy compression. * Acceptable values are defined in [[shortOrcCompressionCodecNames]]. @@ -33,8 +37,8 @@ private[orc] class OrcOptions(@transient private val parameters: Map[String, Str // `orc.compress` is a ORC configuration. So, here we respect this as an option but // `compression` has higher precedence than `orc.compress`. It means if both are set, // we will use `compression`. 
- val orcCompressionConf = parameters.get(OrcRelation.ORC_COMPRESSION) - val codecName = parameters + val orcCompressionConf = caseInsensitiveOptions.get(OrcRelation.ORC_COMPRESSION) + val codecName = caseInsensitiveOptions .get("compression") .orElse(orcCompressionConf) .getOrElse("snappy").toLowerCase diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 0f37cd7bf365..12f948041a8a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -146,6 +146,10 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA sql("DROP TABLE IF EXISTS orcNullValues") } + + test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { + assert(new OrcOptions(Map("Orc.Compress" -> "NONE")).compressionCodec == "NONE") + } } class OrcSourceSuite extends OrcSuite { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 3644ff952eb0..172d0c3723db 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExecutedCommandExec import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoDataSourceCommand, InsertIntoHadoopFsRelationCommand, LogicalRelation} +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions import org.apache.spark.sql.hive.execution.HiveTableScanExec import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -809,6 +810,13 @@ class ParquetSourceSuite extends ParquetPartitioningTest { checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) assert(df4.columns === Array("str", "max_int")) } + + test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val option = new ParquetOptions(Map("Compression" -> "uncompressed"), spark.sessionState.conf) + assert(option.compressionCodecClassName == "UNCOMPRESSED") + } + } } /** From 6570665c8a2458e31ecb3f1262e64970eb686387 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 14 Nov 2016 20:59:09 -0800 Subject: [PATCH 2/2] Replace the type of `parameters` with `CaseInsensitiveMap`. 
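
This second commit pushes the case-insensitivity into the option classes themselves:
each *Options class now takes a CaseInsensitiveMap directly and exposes an auxiliary
Map[String, String] constructor for existing call sites, so the per-class
`caseInsensitiveOptions` wrapper field from the first commit can be dropped. Below is a
minimal sketch of that constructor pattern; `SampleOptions` and its single
`compressionCodec` field are illustrative names only, not part of the patch, and it
assumes the CaseInsensitiveMap class added in the first commit is on the classpath:

    import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

    // Illustrative only: the pattern this commit applies to JSONOptions, CSVOptions,
    // JDBCOptions, ParquetOptions, FileStreamOptions and OrcOptions.
    class SampleOptions(@transient private val parameters: CaseInsensitiveMap)
      extends Serializable {

      // The auxiliary constructor wraps plain maps, so call sites that pass
      // Map[String, String] keep compiling while every lookup becomes case-insensitive.
      def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))

      // Lookups read straight from `parameters`; no per-class wrapper field is needed.
      val compressionCodec: Option[String] = parameters.get("compression")
    }

Keeping the auxiliary constructor means callers that still build options from a plain
Map[String, String], such as the test suites in this series, need no changes.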
--- .../spark/sql/catalyst/json/JSONOptions.scala | 37 ++++++++---------- .../spark/sql/execution/command/ddl.scala | 2 +- .../datasources/csv/CSVOptions.scala | 36 ++++++++--------- .../datasources/jdbc/JDBCOptions.scala | 39 +++++++++---------- .../datasources/parquet/ParquetOptions.scala | 12 +++--- .../streaming/FileStreamOptions.scala | 19 +++++---- .../datasources/json/JsonSuite.scala | 5 ++- .../datasources/parquet/ParquetIOSuite.scala | 7 ++++ .../spark/sql/hive/orc/OrcOptions.scala | 8 ++-- .../apache/spark/sql/hive/parquetSuites.scala | 7 ---- 10 files changed, 82 insertions(+), 90 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index cd62c85a1788..38e191bbbad6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -31,45 +31,40 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. */ private[sql] class JSONOptions( - @transient private val parameters: Map[String, String]) + @transient private val parameters: CaseInsensitiveMap) extends Logging with Serializable { - private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters)) val samplingRatio = - caseInsensitiveOptions.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) val primitivesAsString = - caseInsensitiveOptions.get("primitivesAsString").map(_.toBoolean).getOrElse(false) + parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) val prefersDecimal = - caseInsensitiveOptions.get("prefersDecimal").map(_.toBoolean).getOrElse(false) + parameters.get("prefersDecimal").map(_.toBoolean).getOrElse(false) val allowComments = - caseInsensitiveOptions.get("allowComments").map(_.toBoolean).getOrElse(false) + parameters.get("allowComments").map(_.toBoolean).getOrElse(false) val allowUnquotedFieldNames = - caseInsensitiveOptions.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false) + parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false) val allowSingleQuotes = - caseInsensitiveOptions.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true) + parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true) val allowNumericLeadingZeros = - caseInsensitiveOptions.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false) + parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false) val allowNonNumericNumbers = - caseInsensitiveOptions.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) val allowBackslashEscapingAnyCharacter = - caseInsensitiveOptions.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean) - .getOrElse(false) - val compressionCodec = - caseInsensitiveOptions.get("compression").map(CompressionCodecs.getCodecClassName) - private val parseMode = caseInsensitiveOptions.getOrElse("mode", "PERMISSIVE") - val columnNameOfCorruptRecord = caseInsensitiveOptions.get("columnNameOfCorruptRecord") + parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) + val compressionCodec = 
parameters.get("compression").map(CompressionCodecs.getCodecClassName) + private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord") // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = - FastDateFormat.getInstance( - caseInsensitiveOptions.getOrElse("dateFormat", "yyyy-MM-dd"), - Locale.US) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - caseInsensitiveOptions.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), - Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US) // Parse mode flags if (!ParseModes.isValidMode(parseMode)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 6c1c398940d0..588aa05c37b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryComparison} import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, PredicateHelper} -import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 9831949a8329..21e50307b5ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -25,13 +25,13 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} -private[csv] class CSVOptions(@transient private val parameters: Map[String, String]) +private[csv] class CSVOptions(@transient private val parameters: CaseInsensitiveMap) extends Logging with Serializable { - private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters)) private def getChar(paramName: String, default: Char): Char = { - val paramValue = caseInsensitiveOptions.get(paramName) + val paramValue = parameters.get(paramName) paramValue match { case None => default case Some(null) => default @@ -42,7 +42,7 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str } private def getInt(paramName: String, default: Int): Int = { - val paramValue = caseInsensitiveOptions.get(paramName) + val paramValue = parameters.get(paramName) paramValue match { case None => default case Some(null) => default @@ -56,7 +56,7 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str } private def getBool(paramName: String, default: Boolean = false): Boolean = { - val param = 
caseInsensitiveOptions.getOrElse(paramName, default.toString) + val param = parameters.getOrElse(paramName, default.toString) if (param == null) { default } else if (param.toLowerCase == "true") { @@ -69,10 +69,10 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str } val delimiter = CSVTypeCast.toChar( - caseInsensitiveOptions.getOrElse("sep", caseInsensitiveOptions.getOrElse("delimiter", ","))) - private val parseMode = caseInsensitiveOptions.getOrElse("mode", "PERMISSIVE") - val charset = caseInsensitiveOptions.getOrElse("encoding", - caseInsensitiveOptions.getOrElse("charset", StandardCharsets.UTF_8.name())) + parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) + private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val charset = parameters.getOrElse("encoding", + parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) val quote = getChar("quote", '\"') val escape = getChar("escape", '\\') @@ -92,28 +92,26 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str val dropMalformed = ParseModes.isDropMalformedMode(parseMode) val permissive = ParseModes.isPermissiveMode(parseMode) - val nullValue = caseInsensitiveOptions.getOrElse("nullValue", "") + val nullValue = parameters.getOrElse("nullValue", "") - val nanValue = caseInsensitiveOptions.getOrElse("nanValue", "NaN") + val nanValue = parameters.getOrElse("nanValue", "NaN") - val positiveInf = caseInsensitiveOptions.getOrElse("positiveInf", "Inf") - val negativeInf = caseInsensitiveOptions.getOrElse("negativeInf", "-Inf") + val positiveInf = parameters.getOrElse("positiveInf", "Inf") + val negativeInf = parameters.getOrElse("negativeInf", "-Inf") val compressionCodec: Option[String] = { - val name = caseInsensitiveOptions.get("compression").orElse(caseInsensitiveOptions.get("codec")) + val name = parameters.get("compression").orElse(parameters.get("codec")) name.map(CompressionCodecs.getCodecClassName) } // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = - FastDateFormat.getInstance( - caseInsensitiveOptions.getOrElse("dateFormat", "yyyy-MM-dd"), - Locale.US) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - caseInsensitiveOptions.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US) val maxColumns = getInt("maxColumns", 20480) @@ -132,7 +130,7 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str object CSVOptions { - def apply(): CSVOptions = new CSVOptions(Map.empty) + def apply(): CSVOptions = new CSVOptions(new CaseInsensitiveMap(Map.empty)) def apply(paramName: String, paramValue: String): CSVOptions = { new CSVOptions(Map(paramName -> paramValue)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 69466fa034aa..7f419b5788c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -28,23 +28,23 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap * Options for the JDBC data source. 
*/ class JDBCOptions( - @transient private val parameters: Map[String, String]) + @transient private val parameters: CaseInsensitiveMap) extends Serializable { import JDBCOptions._ - private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters)) def this(url: String, table: String, parameters: Map[String, String]) = { - this(parameters ++ Map( + this(new CaseInsensitiveMap(parameters ++ Map( JDBCOptions.JDBC_URL -> url, - JDBCOptions.JDBC_TABLE_NAME -> table)) + JDBCOptions.JDBC_TABLE_NAME -> table))) } val asConnectionProperties: Properties = { val properties = new Properties() // We should avoid to pass the options into properties. See SPARK-17776. - caseInsensitiveOptions.filterKeys(!jdbcOptionNames.contains(_)) + parameters.filterKeys(!jdbcOptionNames.contains(_)) .foreach { case (k, v) => properties.setProperty(k, v) } properties } @@ -52,19 +52,18 @@ class JDBCOptions( // ------------------------------------------------------------ // Required parameters // ------------------------------------------------------------ - require(caseInsensitiveOptions.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.") - require(caseInsensitiveOptions.isDefinedAt(JDBC_TABLE_NAME), - s"Option '$JDBC_TABLE_NAME' is required.") + require(parameters.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.") + require(parameters.isDefinedAt(JDBC_TABLE_NAME), s"Option '$JDBC_TABLE_NAME' is required.") // a JDBC URL - val url = caseInsensitiveOptions(JDBC_URL) + val url = parameters(JDBC_URL) // name of table - val table = caseInsensitiveOptions(JDBC_TABLE_NAME) + val table = parameters(JDBC_TABLE_NAME) // ------------------------------------------------------------ // Optional parameters // ------------------------------------------------------------ val driverClass = { - val userSpecifiedDriverClass = caseInsensitiveOptions.get(JDBC_DRIVER_CLASS) + val userSpecifiedDriverClass = parameters.get(JDBC_DRIVER_CLASS) userSpecifiedDriverClass.foreach(DriverRegistry.register) // Performing this part of the logic on the driver guards against the corner-case where the @@ -79,19 +78,19 @@ class JDBCOptions( // Optional parameters only for reading // ------------------------------------------------------------ // the column used to partition - val partitionColumn = caseInsensitiveOptions.getOrElse(JDBC_PARTITION_COLUMN, null) + val partitionColumn = parameters.getOrElse(JDBC_PARTITION_COLUMN, null) // the lower bound of partition column - val lowerBound = caseInsensitiveOptions.getOrElse(JDBC_LOWER_BOUND, null) + val lowerBound = parameters.getOrElse(JDBC_LOWER_BOUND, null) // the upper bound of the partition column - val upperBound = caseInsensitiveOptions.getOrElse(JDBC_UPPER_BOUND, null) + val upperBound = parameters.getOrElse(JDBC_UPPER_BOUND, null) // the number of partitions - val numPartitions = caseInsensitiveOptions.getOrElse(JDBC_NUM_PARTITIONS, null) + val numPartitions = parameters.getOrElse(JDBC_NUM_PARTITIONS, null) require(partitionColumn == null || (lowerBound != null && upperBound != null && numPartitions != null), s"If '$JDBC_PARTITION_COLUMN' is specified then '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND'," + s" and '$JDBC_NUM_PARTITIONS' are required.") val fetchSize = { - val size = caseInsensitiveOptions.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt + val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt require(size >= 0, s"Invalid value `${size.toString}` for parameter " + 
s"`$JDBC_BATCH_FETCH_SIZE`. The minimum value is 0. When the value is 0, " + @@ -103,20 +102,20 @@ class JDBCOptions( // Optional parameters only for writing // ------------------------------------------------------------ // if to truncate the table from the JDBC database - val isTruncate = caseInsensitiveOptions.getOrElse(JDBC_TRUNCATE, "false").toBoolean + val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean // the create table option , which can be table_options or partition_options. // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options - val createTableOptions = caseInsensitiveOptions.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "") + val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "") val batchSize = { - val size = caseInsensitiveOptions.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt + val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt require(size >= 1, s"Invalid value `${size.toString}` for parameter " + s"`$JDBC_BATCH_INSERT_SIZE`. The minimum value is 1.") size } val isolationLevel = - caseInsensitiveOptions.getOrElse(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match { + parameters.getOrElse(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match { case "NONE" => Connection.TRANSACTION_NONE case "READ_UNCOMMITTED" => Connection.TRANSACTION_READ_UNCOMMITTED case "READ_COMMITTED" => Connection.TRANSACTION_READ_COMMITTED diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index aa9b729b2bd5..a81a95d51085 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -25,22 +25,22 @@ import org.apache.spark.sql.internal.SQLConf /** * Options for the Parquet data source. */ -private[sql] class ParquetOptions( - @transient private val parameters: Map[String, String], +private[parquet] class ParquetOptions( + @transient private val parameters: CaseInsensitiveMap, @transient private val sqlConf: SQLConf) extends Serializable { import ParquetOptions._ - private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + def this(parameters: Map[String, String], sqlConf: SQLConf) = + this(new CaseInsensitiveMap(parameters), sqlConf) /** * Compression codec to use. By default use the value specified in SQLConf. * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = - caseInsensitiveOptions.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase + val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase) throw new IllegalArgumentException(s"Codec [$codecName] " + @@ -53,7 +53,7 @@ private[sql] class ParquetOptions( * Whether it merges schemas or not. When the given Parquet files have different schemas, * the schemas can be merged. By default use the value specified in SQLConf. 
*/ - val mergeSchema: Boolean = caseInsensitiveOptions + val mergeSchema: Boolean = parameters .get(MERGE_SCHEMA) .map(_.toBoolean) .getOrElse(sqlConf.isParquetSchemaMergingEnabled) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala index 644114fbddf0..fdea65cb10ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -26,16 +26,15 @@ import org.apache.spark.util.Utils /** * User specified options for file streams. */ -class FileStreamOptions(parameters: Map[String, String]) extends Logging { +class FileStreamOptions(parameters: CaseInsensitiveMap) extends Logging { - private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters)) - val maxFilesPerTrigger: Option[Int] = caseInsensitiveOptions.get("maxFilesPerTrigger").map { - str => - Try(str.toInt).toOption.filter(_ > 0).getOrElse { - throw new IllegalArgumentException( - s"Invalid value '$str' for option 'maxFilesPerTrigger', must be a positive integer") - } + val maxFilesPerTrigger: Option[Int] = parameters.get("maxFilesPerTrigger").map { str => + Try(str.toInt).toOption.filter(_ > 0).getOrElse { + throw new IllegalArgumentException( + s"Invalid value '$str' for option 'maxFilesPerTrigger', must be a positive integer") + } } /** @@ -49,9 +48,9 @@ class FileStreamOptions(parameters: Map[String, String]) extends Logging { * Default to a week. */ val maxFileAgeMs: Long = - Utils.timeStringAsMs(caseInsensitiveOptions.getOrElse("maxFileAge", "7d")) + Utils.timeStringAsMs(parameters.getOrElse("maxFileAge", "7d")) /** Options as specified by the user, in a case-insensitive map, without "path" set. 
*/ val optionMapWithoutPath: Map[String, String] = - new CaseInsensitiveMap(caseInsensitiveOptions).filterKeys(_ != "path") + parameters.filterKeys(_ != "path") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 27438b4a4588..598e44ec8c19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1366,7 +1366,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { // This is really a test that it doesn't throw an exception - val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map())) + val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map.empty[String, String])) assert(StructType(Seq()) === emptySchema) } @@ -1390,7 +1390,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-8093 Erase empty structs") { - val emptySchema = InferSchema.infer(emptyRecords, "", new JSONOptions(Map())) + val emptySchema = InferSchema.infer( + emptyRecords, "", new JSONOptions(Map.empty[String, String])) assert(StructType(Seq()) === emptySchema) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 580eade4b141..acdadb3103c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -736,6 +736,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } } + + test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val option = new ParquetOptions(Map("Compression" -> "uncompressed"), spark.sessionState.conf) + assert(option.compressionCodecClassName == "UNCOMPRESSED") + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index 2d1fff54699f..ac587ab99ae2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -22,12 +22,12 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap /** * Options for the ORC data source. */ -private[orc] class OrcOptions(@transient private val parameters: Map[String, String]) +private[orc] class OrcOptions(@transient private val parameters: CaseInsensitiveMap) extends Serializable { import OrcOptions._ - private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters)) /** * Compression codec to use. By default snappy compression. @@ -37,8 +37,8 @@ private[orc] class OrcOptions(@transient private val parameters: Map[String, Str // `orc.compress` is a ORC configuration. So, here we respect this as an option but // `compression` has higher precedence than `orc.compress`. It means if both are set, // we will use `compression`. 
- val orcCompressionConf = caseInsensitiveOptions.get(OrcRelation.ORC_COMPRESSION) - val codecName = caseInsensitiveOptions + val orcCompressionConf = parameters.get(OrcRelation.ORC_COMPRESSION) + val codecName = parameters .get("compression") .orElse(orcCompressionConf) .getOrElse("snappy").toLowerCase diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 172d0c3723db..2ce60fe58921 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -810,13 +810,6 @@ class ParquetSourceSuite extends ParquetPartitioningTest { checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) assert(df4.columns === Array("str", "max_int")) } - - test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { - withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { - val option = new ParquetOptions(Map("Compression" -> "uncompressed"), spark.sessionState.conf) - assert(option.compressionCodecClassName == "UNCOMPRESSED") - } - } } /**
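
For reference, the behaviour that the new option-key tests above exercise (for example
"TiMeStampFormat", "PREfersdecimaL", "Orc.Compress", "Compression") comes entirely from
CaseInsensitiveMap lower-casing its keys. A minimal standalone sketch, assuming the
CaseInsensitiveMap class added in the first commit is on the classpath; the option keys
and values below are made up for illustration:

    import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

    object CaseInsensitiveMapSketch {
      def main(args: Array[String]): Unit = {
        val options = new CaseInsensitiveMap(
          Map("TiMeStampFormat" -> "yyyy-mm", "path" -> "/tmp/in"))

        // get/apply lower-case the requested key before the lookup.
        assert(options.get("timestampformat") == Some("yyyy-mm"))
        assert(options("PATH") == "/tmp/in")

        // `+` lower-cases the inserted key too; the result is a plain Map[String, String].
        val withCodec = options + ("Compression" -> "snappy")
        assert(withCodec.get("compression") == Some("snappy"))

        // `-` removes entries regardless of the caller's casing.
        assert((options - "Path").get("path").isEmpty)
      }
    }

Because `+`, `-` and `filterKeys` return ordinary maps rather than another
CaseInsensitiveMap, derived values such as `optionMapWithoutPath` in FileStreamOptions
are typed as plain Map[String, String].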