File: JSONOptions.scala
@@ -23,17 +23,19 @@ import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import org.apache.commons.lang3.time.FastDateFormat

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.{CompressionCodecs, ParseModes}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}

/**
* Options for parsing JSON data into Spark SQL rows.
*
* Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
*/
private[sql] class JSONOptions(
@transient private val parameters: Map[String, String])
@transient private val parameters: CaseInsensitiveMap)
extends Logging with Serializable {

def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))

val samplingRatio =
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
val primitivesAsString =
File: CaseInsensitiveMap.scala (new file)
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

/**
* Builds a map in which keys are case insensitive
*/
class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
with Serializable {

val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase))

override def get(k: String): Option[String] = baseMap.get(k.toLowerCase)

override def + [B1 >: String](kv: (String, B1)): Map[String, B1] =
baseMap + kv.copy(_1 = kv._1.toLowerCase)

override def iterator: Iterator[(String, String)] = baseMap.iterator

override def -(key: String): Map[String, String] = baseMap - key.toLowerCase
}
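
For reference, a short REPL-style sketch of how the class above behaves once this change is in; the option keys and values below are made up for illustration:

```scala
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

// Keys are lowercased when the map is built, so lookups ignore case.
val options = new CaseInsensitiveMap(Map("TimestampFormat" -> "yyyy-MM", "Path" -> "/tmp/data"))
options.get("timestampformat")    // Some("yyyy-MM")
options.get("PATH")               // Some("/tmp/data")

// + and - also lowercase the key they are given, although they return a
// plain Map[String, String] rather than another CaseInsensitiveMap.
val withCodec = options + ("Compression" -> "snappy")
withCodec.get("compression")      // Some("snappy")
(options - "pAtH").get("path")    // None
```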
File: ddl.scala (execution/command)
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryComparison}
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, PredicateHelper}
import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, PartitioningUtils}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration

File: DataSource.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
@@ -80,13 +81,13 @@ case class DataSource(

lazy val providingClass: Class[_] = DataSource.lookupDataSource(className)
lazy val sourceInfo = sourceSchema()
private val caseInsensitiveOptions = new CaseInsensitiveMap(options)

/**
* Infer the schema of the given FileFormat, returns a pair of schema and partition column names.
*/
private def inferFileFormatSchema(format: FileFormat): (StructType, Seq[String]) = {
userSpecifiedSchema.map(_ -> partitionColumns).orElse {
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
val allPaths = caseInsensitiveOptions.get("path")
val globbedPaths = allPaths.toSeq.flatMap { path =>
val hdfsPath = new Path(path)
@@ -114,11 +115,10 @@ case class DataSource(
providingClass.newInstance() match {
case s: StreamSourceProvider =>
val (name, schema) = s.sourceSchema(
sparkSession.sqlContext, userSpecifiedSchema, className, options)
sparkSession.sqlContext, userSpecifiedSchema, className, caseInsensitiveOptions)
SourceInfo(name, schema, Nil)

case format: FileFormat =>
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
val path = caseInsensitiveOptions.getOrElse("path", {
throw new IllegalArgumentException("'path' is not specified")
})
@@ -158,10 +158,14 @@
providingClass.newInstance() match {
case s: StreamSourceProvider =>
s.createSource(
sparkSession.sqlContext, metadataPath, userSpecifiedSchema, className, options)
sparkSession.sqlContext,
metadataPath,
userSpecifiedSchema,
className,
caseInsensitiveOptions)

case format: FileFormat =>
val path = new CaseInsensitiveMap(options).getOrElse("path", {
val path = caseInsensitiveOptions.getOrElse("path", {
throw new IllegalArgumentException("'path' is not specified")
})
new FileStreamSource(
@@ -171,7 +175,7 @@
schema = sourceInfo.schema,
partitionColumns = sourceInfo.partitionColumns,
metadataPath = metadataPath,
options = options)
options = caseInsensitiveOptions)
case _ =>
throw new UnsupportedOperationException(
s"Data source $className does not support streamed reading")
@@ -182,18 +186,17 @@
def createSink(outputMode: OutputMode): Sink = {
providingClass.newInstance() match {
case s: StreamSinkProvider =>
s.createSink(sparkSession.sqlContext, options, partitionColumns, outputMode)
s.createSink(sparkSession.sqlContext, caseInsensitiveOptions, partitionColumns, outputMode)

case fileFormat: FileFormat =>
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
val path = caseInsensitiveOptions.getOrElse("path", {
throw new IllegalArgumentException("'path' is not specified")
})
if (outputMode != OutputMode.Append) {
throw new IllegalArgumentException(
s"Data source $className does not support $outputMode output mode")
}
new FileStreamSink(sparkSession, path, fileFormat, partitionColumns, options)
new FileStreamSink(sparkSession, path, fileFormat, partitionColumns, caseInsensitiveOptions)

case _ =>
throw new UnsupportedOperationException(
@@ -234,7 +237,6 @@ case class DataSource(
* that files already exist, we don't need to check them again.
*/
def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = {
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
// TODO: Throw when too much is given.
case (dataSource: SchemaRelationProvider, Some(schema)) =>
@@ -274,7 +276,7 @@ case class DataSource(
dataSchema = dataSchema,
bucketSpec = None,
format,
options)(sparkSession)
caseInsensitiveOptions)(sparkSession)

// This is a non-streaming file based datasource.
case (format: FileFormat, _) =>
@@ -358,13 +360,13 @@ case class DataSource(

providingClass.newInstance() match {
case dataSource: CreatableRelationProvider =>
dataSource.createRelation(sparkSession.sqlContext, mode, options, data)
dataSource.createRelation(sparkSession.sqlContext, mode, caseInsensitiveOptions, data)
case format: FileFormat =>
// Don't glob path for the write path. The contracts here are:
// 1. Only one output path can be specified on the write path;
// 2. Output path must be a legal HDFS style file system path;
// 3. It's OK that the output path doesn't exist yet;
val allPaths = paths ++ new CaseInsensitiveMap(options).get("path")
val allPaths = paths ++ caseInsensitiveOptions.get("path")
val outputPath = if (allPaths.length == 1) {
val path = new Path(allPaths.head)
val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
@@ -391,7 +393,7 @@ case class DataSource(
// TODO: Case sensitivity.
val sameColumns =
existingPartitionColumns.map(_.toLowerCase()) == partitionColumns.map(_.toLowerCase())
if (existingPartitionColumns.size > 0 && !sameColumns) {
if (existingPartitionColumns.nonEmpty && !sameColumns) {
throw new AnalysisException(
s"""Requested partitioning does not match existing partitioning.
|Existing partitioning columns:
File: CSVOptions.scala
@@ -23,11 +23,13 @@ import java.util.Locale
import org.apache.commons.lang3.time.FastDateFormat

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.{CompressionCodecs, ParseModes}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}

private[csv] class CSVOptions(@transient private val parameters: Map[String, String])
private[csv] class CSVOptions(@transient private val parameters: CaseInsensitiveMap)
extends Logging with Serializable {

def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))

private def getChar(paramName: String, default: Char): Char = {
val paramValue = parameters.get(paramName)
paramValue match {
@@ -128,7 +130,7 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str

object CSVOptions {

def apply(): CSVOptions = new CSVOptions(Map.empty)
def apply(): CSVOptions = new CSVOptions(new CaseInsensitiveMap(Map.empty))

def apply(paramName: String, paramValue: String): CSVOptions = {
new CSVOptions(Map(paramName -> paramValue))
File: ddl.scala (execution/datasources)
@@ -96,21 +96,3 @@ case class RefreshResource(path: String)
Seq.empty[Row]
}
}

/**
* Builds a map in which keys are case insensitive
*/
class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
with Serializable {

val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase))

override def get(k: String): Option[String] = baseMap.get(k.toLowerCase)

override def + [B1 >: String](kv: (String, B1)): Map[String, B1] =
baseMap + kv.copy(_1 = kv._1.toLowerCase)

override def iterator: Iterator[(String, String)] = baseMap.iterator

override def -(key: String): Map[String, String] = baseMap - key.toLowerCase
}
File: JDBCOptions.scala
@@ -22,19 +22,23 @@ import java.util.Properties

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

/**
* Options for the JDBC data source.
*/
class JDBCOptions(
@transient private val parameters: Map[String, String])
@transient private val parameters: CaseInsensitiveMap)
extends Serializable {

import JDBCOptions._

def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))

def this(url: String, table: String, parameters: Map[String, String]) = {
this(parameters ++ Map(
this(new CaseInsensitiveMap(parameters ++ Map(
JDBCOptions.JDBC_URL -> url,
JDBCOptions.JDBC_TABLE_NAME -> table))
JDBCOptions.JDBC_TABLE_NAME -> table)))
}

val asConnectionProperties: Properties = {
File: ParquetOptions.scala
@@ -19,18 +19,22 @@ package org.apache.spark.sql.execution.datasources.parquet

import org.apache.parquet.hadoop.metadata.CompressionCodecName

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf

/**
* Options for the Parquet data source.
*/
private[parquet] class ParquetOptions(
@transient private val parameters: Map[String, String],
@transient private val parameters: CaseInsensitiveMap,
@transient private val sqlConf: SQLConf)
extends Serializable {

import ParquetOptions._

def this(parameters: Map[String, String], sqlConf: SQLConf) =
this(new CaseInsensitiveMap(parameters), sqlConf)

/**
* Compression codec to use. By default use the value specified in SQLConf.
* Acceptable values are defined in [[shortParquetCompressionCodecNames]].
File: FileStreamOptions.scala
@@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.streaming
import scala.util.Try

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.util.Utils

/**
* User specified options for file streams.
*/
class FileStreamOptions(parameters: Map[String, String]) extends Logging {
class FileStreamOptions(parameters: CaseInsensitiveMap) extends Logging {

def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))

val maxFilesPerTrigger: Option[Int] = parameters.get("maxFilesPerTrigger").map { str =>
Try(str.toInt).toOption.filter(_ > 0).getOrElse {
@@ -50,5 +52,5 @@ class FileStreamOptions(parameters: Map[String, String]) extends Logging {

/** Options as specified by the user, in a case-insensitive map, without "path" set. */
val optionMapWithoutPath: Map[String, String] =
new CaseInsensitiveMap(parameters).filterKeys(_ != "path")
parameters.filterKeys(_ != "path")
}
File: CSVInferSchemaSuite.scala
@@ -109,4 +109,9 @@ class CSVInferSchemaSuite extends SparkFunSuite {
val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))
assert(mergedNullTypes.deep == Array(NullType).deep)
}

test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"))
assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType)
}
}
File: JsonSuite.scala
@@ -1366,7 +1366,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {

test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
// This is really a test that it doesn't throw an exception
val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map()))
val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map.empty[String, String]))
Review comment (Contributor): unnecessary change?

Reply (dongjoon-hyun, Member and author, Nov 16, 2016): In this case we expect a Map[String, String], but Map() or Map.empty is inferred as Map[Nothing, Nothing], so the compiler could not find a matching JSONOptions constructor without the explicit type parameters.

assert(StructType(Seq()) === emptySchema)
}

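As an aside, a short REPL-style sketch of the type-inference point in the review reply above. The Options and CaseInsensitiveMap classes here are stand-ins that mirror the two JSONOptions constructors; they are illustrative, not code from this PR:

```scala
// Stand-ins mirroring the pair of constructors added in this PR.
class CaseInsensitiveMap(map: Map[String, String])
class Options(val parameters: CaseInsensitiveMap) {
  def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
}

val untyped = Map()                      // inferred as Map[Nothing, Nothing]
// new Options(untyped)                  // does not compile: Map is invariant in its key type,
                                         //   so Map[Nothing, Nothing] matches neither constructor

val typed = Map.empty[String, String]    // explicitly Map[String, String]
new Options(typed)                       // resolves to the auxiliary constructor
```

Because the constructor is overloaded, the argument is typed without an expected type, so Map() keeps its Map[Nothing, Nothing] inference instead of being widened to Map[String, String].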
@@ -1390,7 +1390,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}

test("SPARK-8093 Erase empty structs") {
val emptySchema = InferSchema.infer(emptyRecords, "", new JSONOptions(Map()))
val emptySchema = InferSchema.infer(
emptyRecords, "", new JSONOptions(Map.empty[String, String]))
assert(StructType(Seq()) === emptySchema)
}

@@ -1749,4 +1750,18 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
checkAnswer(stringTimestampsWithFormat, expectedStringDatesWithFormat)
}
}

test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
val records = sparkContext
.parallelize("""{"a": 3, "b": 1.1}""" :: """{"a": 3.1, "b": 0.000001}""" :: Nil)

val schema = StructType(
StructField("a", DecimalType(21, 1), true) ::
StructField("b", DecimalType(7, 6), true) :: Nil)

val df1 = spark.read.option("prefersDecimal", "true").json(records)
assert(df1.schema == schema)
val df2 = spark.read.option("PREfersdecimaL", "true").json(records)
assert(df2.schema == schema)
}
}
File: ParquetIOSuite.scala
@@ -736,6 +736,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
}
}
}

test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") {
val option = new ParquetOptions(Map("Compression" -> "uncompressed"), spark.sessionState.conf)
assert(option.compressionCodecClassName == "UNCOMPRESSED")
}
}
}

class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
File: JDBCWriteSuite.scala
@@ -303,4 +303,13 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(e.contains("If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," +
" and 'numPartitions' are required."))
}

test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
df.write.format("jdbc")
.option("Url", url1)
.option("dbtable", "TEST.SAVETEST")
.options(properties.asScala)
.save()
}
}