diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index c9c342df82c9..6e3436c429d7 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -20,6 +20,7 @@ package org.apache.spark
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConverters._
+import scala.collection.immutable
 import scala.collection.mutable.LinkedHashSet
 
 import org.apache.avro.{Schema, SchemaNormalization}
@@ -384,6 +385,10 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
       .map { case (k, v) => (k.substring(prefix.length), v) }
   }
 
+  /** Get all parameters as a Map */
+  def getAllAsMap: immutable.Map[String, String] = {
+    settings.asScala.toMap
+  }
   /** Get a parameter as an integer, falling back to a default if not set */
   def getInt(key: String, defaultValue: Int): Int = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index a77937efd7e1..4dd100623216 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -149,7 +149,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         paths = paths,
         userSpecifiedSchema = userSpecifiedSchema,
         className = source,
-        options = extraOptions.toMap).resolveRelation())
+        options = optionsOverriddenWith(extraOptions.toMap)).resolveRelation())
   }
 
   /**
@@ -551,4 +551,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
 
   private var extraOptions = new scala.collection.mutable.HashMap[String, String]
 
+  // Returns all options set in `SparkConf`, `SQLConf`, and the given data source `options`.
+  // If the same key appears in more than one of them, the value in `options` takes precedence.
+  private def optionsOverriddenWith(options: Map[String, String]): Map[String, String] = {
+    sparkSession.sparkContext.conf.getAllAsMap ++ sparkSession.conf.getAll ++ options
+  }
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index e485b52b43f7..49450a062e09 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -470,7 +470,8 @@ case class FileSourceScanExec(
     val defaultMaxSplitBytes =
       fsRelation.sparkSession.sessionState.conf.filesMaxPartitionBytes
     val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes
-    val defaultParallelism = fsRelation.sparkSession.sparkContext.defaultParallelism
+    val defaultParallelism = fsRelation.options.get("spark.default.parallelism").map(_.toInt)
+      .getOrElse(fsRelation.sparkSession.sparkContext.defaultParallelism)
     val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
     val bytesPerCore = totalBytes / defaultParallelism
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 031a0fe57893..e02751e22943 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -232,7 +232,7 @@ class ParquetFileFormat
         .orElse(filesByType.data.headOption)
         .toSeq
     }
-    ParquetFileFormat.mergeSchemasInParallel(filesToTouch, sparkSession)
+    ParquetFileFormat.mergeSchemasInParallel(filesToTouch, sparkSession, parameters)
   }
 
   case class FileTypes(
@@ -561,9 +561,10 @@ object ParquetFileFormat extends Logging {
    * slow. And basically locality is not available when using S3 (you can't run computation on
    * S3 nodes).
    */
-  def mergeSchemasInParallel(
+  private def mergeSchemasInParallel(
       filesToTouch: Seq[FileStatus],
-      sparkSession: SparkSession): Option[StructType] = {
+      sparkSession: SparkSession,
+      parameters: Map[String, String]): Option[StructType] = {
     val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString
     val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp
     val writeLegacyParquetFormat = sparkSession.sessionState.conf.writeLegacyParquetFormat
@@ -584,8 +585,9 @@ object ParquetFileFormat extends Logging {
 
     // Set the number of partitions to prevent following schema reads from generating many tasks
     // in case of a small number of parquet files.
-    val numParallelism = Math.min(Math.max(partialFileStatusInfo.size, 1),
-      sparkSession.sparkContext.defaultParallelism)
+    val defaultParallelism = parameters.get("spark.default.parallelism").map(_.toInt)
+      .getOrElse(sparkSession.sparkContext.defaultParallelism)
+    val numParallelism = Math.min(Math.max(partialFileStatusInfo.size, 1), defaultParallelism)
 
     // Issues a Spark job to read Parquet schema in parallel.
     val partiallyMergedSchemas =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
index a81a95d51085..e51b7e5fc4c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
@@ -40,7 +40,8 @@ private[parquet] class ParquetOptions(
    * Acceptable values are defined in [[shortParquetCompressionCodecNames]].
    */
   val compressionCodecClassName: String = {
-    val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase
+    val codecName = parameters.getOrElse("compression", parameters.getOrElse(
+      "spark.sql.parquet.compression.codec", sqlConf.parquetCompressionCodec)).toLowerCase
     if (!shortParquetCompressionCodecNames.contains(codecName)) {
       val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase)
       throw new IllegalArgumentException(s"Codec [$codecName] " +
@@ -55,8 +56,9 @@ private[parquet] class ParquetOptions(
    */
   val mergeSchema: Boolean = parameters
     .get(MERGE_SCHEMA)
-    .map(_.toBoolean)
-    .getOrElse(sqlConf.isParquetSchemaMergingEnabled)
+    .map(_.toBoolean)
+    .getOrElse(parameters.get("spark.sql.parquet.mergeSchema").map(_.toBoolean)
+      .getOrElse(sqlConf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)))
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index d900ce7bb237..653e72fafbee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -274,6 +274,29 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
     }
   }
 
+  test("datasource-specific minPartitions") {
+    val table =
+      createTable(
+        files = Seq(
+          "file1" -> 1,
+          "file2" -> 1,
+          "file3" -> 1,
+          "file4" -> 1,
+          "file5" -> 1,
+          "file6" -> 1,
+          "file7" -> 1,
+          "file8" -> 1,
+          "file9" -> 1),
+        options = Map("spark.default.parallelism" -> "3"))
+
+    checkScan(table.select('c1)) { partitions =>
+      assert(partitions.size == 3)
+      assert(partitions(0).files.size == 3)
+      assert(partitions(1).files.size == 3)
+      assert(partitions(2).files.size == 3)
+    }
+  }
+
   test("Locality support for FileScanRDD") {
     val partition = FilePartition(0, Seq(
       PartitionedFile(InternalRow.empty, "fakePath0", 0, 10, Array("host0", "host1")),
@@ -526,7 +549,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
    */
   def createTable(
       files: Seq[(String, Int)],
-      buckets: Int = 0): DataFrame = {
+      buckets: Int = 0,
+      options: Map[String, String] = Map.empty): DataFrame = {
     val tempDir = Utils.createTempDir()
     files.foreach { case (name, size) =>
@@ -537,6 +561,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
 
     val df = spark.read
      .format(classOf[TestFileFormat].getName)
+      .options(options)
      .load(tempDir.getCanonicalPath)
 
     if (buckets > 0) {
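
Example usage (not part of the patch): a minimal sketch of how a caller could exercise the new per-read overrides once the changes above are applied. The object name, input path, and session setup below are hypothetical; only the option keys "spark.default.parallelism" and "spark.sql.parquet.mergeSchema" come from the diff.

import org.apache.spark.sql.SparkSession

object PerReadOptionExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("per-read-option-example")
      .master("local[*]")
      .getOrCreate()

    // With the patch applied, these options override the SparkConf/SQLConf values for
    // this read only: "spark.default.parallelism" feeds the bytesPerCore and Parquet
    // schema-merge parallelism calculations, and "spark.sql.parquet.mergeSchema"
    // toggles schema merging. "/tmp/parquet/input" is a placeholder path.
    val df = spark.read
      .format("parquet")
      .option("spark.default.parallelism", "3")
      .option("spark.sql.parquet.mergeSchema", "true")
      .load("/tmp/parquet/input")

    df.show()
    spark.stop()
  }
}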