diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0b1965c438e2..d354b2320177 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1205,6 +1205,20 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(100) + val DISABLED_V2_FILE_DATA_SOURCE_READERS = buildConf("spark.sql.disabledV2FileDataSourceReaders") + .internal() + .doc("A comma-separated list of file data source short names for which DataSourceReader" + + " is disabled. Reads from these sources will fall back to the V1 sources") + .stringConf + .createWithDefault("") + + val DISABLED_V2_FILE_DATA_SOURCE_WRITERS = buildConf("spark.sql.disabledV2FileDataSourceWriters") + .internal() + .doc("A comma-separated list of file data source short names for which DataSourceWriter" + + " is disabled. Writes to these sources will fall back to the V1 FileFormat") + .stringConf + .createWithDefault("") + val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") .internal() .doc("A comma-separated list of fully qualified data source register class names for which" + @@ -1606,6 +1620,10 @@ class SQLConf extends Serializable with Logging { def continuousStreamingExecutorPollIntervalMs: Long = getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) + def disabledV2FileDataSourceReader: String = getConf(DISABLED_V2_FILE_DATA_SOURCE_READERS) + + def disabledV2FileDataSourceWriter: String = getConf(DISABLED_V2_FILE_DATA_SOURCE_WRITERS) + def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) def disabledV2StreamingMicroBatchReaders: String = diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index dcebdc39f0aa..6982ebb80cc9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -58,10 +58,16 @@ public class OrcColumnarBatchReader extends RecordReader { /** * The column IDs of the physical ORC file schema which are required by this reader. - * -1 means this required column doesn't exist in the ORC file. + * -1 means this required column is partition column, or it doesn't exist in the ORC file. */ private int[] requestedColIds; + /** + * The column IDs of the ORC file partition schema which are required by this reader. + * -1 means this required column doesn't exist in the ORC partition columns. + */ + private int[] requestedPartitionColIds; + // Record reader from ORC row batch. private org.apache.orc.RecordReader recordReader; @@ -143,25 +149,34 @@ public void initialize( /** * Initialize columnar batch by setting required schema and partition information. * With this information, this creates ColumnarBatch with the full schema. + * + * @param orcSchema Schema from ORC file reader. + * @param requiredFields All the fields that are required to return, including partition fields. + * @param requestedColIds Requested column ids from orcSchema. -1 if not existed. + * @param requestedPartitionColIds Requested column ids from partition schema. -1 if not existed. + * @param partitionValues Values of partition columns. */ public void initBatch( TypeDescription orcSchema, - int[] requestedColIds, StructField[] requiredFields, - StructType partitionSchema, + int[] requestedColIds, + int[] requestedPartitionColIds, InternalRow partitionValues) { batch = orcSchema.createRowBatch(capacity); assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. - + assert(requiredFields.length == requestedColIds.length); + assert(requiredFields.length == requestedPartitionColIds.length); + // If a required column is also partition column, use partition value and don't read from file. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedPartitionColIds[i] != -1) { + requestedColIds[i] = -1; + } + } + this.requestedPartitionColIds = requestedPartitionColIds; this.requiredFields = requiredFields; this.requestedColIds = requestedColIds; - assert(requiredFields.length == requestedColIds.length); StructType resultSchema = new StructType(requiredFields); - for (StructField f : partitionSchema.fields()) { - resultSchema = resultSchema.add(f); - } - if (copyToSpark) { if (MEMORY_MODE == MemoryMode.OFF_HEAP) { columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); @@ -169,22 +184,18 @@ public void initBatch( columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); } - // Initialize the missing columns once. + // Initialize the missing columns and partition columns once. for (int i = 0; i < requiredFields.length; i++) { - if (requestedColIds[i] == -1) { + if (requestedPartitionColIds[i] != -1) { + ColumnVectorUtils.populate(columnVectors[i], + partitionValues, requestedPartitionColIds[i]); + columnVectors[i].setIsConstant(); + } else if (requestedColIds[i] == -1) { columnVectors[i].putNulls(0, capacity); columnVectors[i].setIsConstant(); } } - if (partitionValues.numFields() > 0) { - int partitionIdx = requiredFields.length; - for (int i = 0; i < partitionValues.numFields(); i++) { - ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); - columnVectors[i + partitionIdx].setIsConstant(); - } - } - columnarBatch = new ColumnarBatch(columnVectors); } else { // Just wrap the ORC column vector instead of copying it to Spark column vector. @@ -192,26 +203,22 @@ public void initBatch( for (int i = 0; i < requiredFields.length; i++) { DataType dt = requiredFields[i].dataType(); - int colId = requestedColIds[i]; - // Initialize the missing columns once. - if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); - missingCol.putNulls(0, capacity); - missingCol.setIsConstant(); - orcVectorWrappers[i] = missingCol; - } else { - orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]); - } - } - - if (partitionValues.numFields() > 0) { - int partitionIdx = requiredFields.length; - for (int i = 0; i < partitionValues.numFields(); i++) { - DataType dt = partitionSchema.fields()[i].dataType(); + if (requestedPartitionColIds[i] != -1) { OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); - ColumnVectorUtils.populate(partitionCol, partitionValues, i); + ColumnVectorUtils.populate(partitionCol, partitionValues, requestedPartitionColIds[i]); partitionCol.setIsConstant(); - orcVectorWrappers[partitionIdx + i] = partitionCol; + orcVectorWrappers[i] = partitionCol; + } else { + int colId = requestedColIds[i]; + // Initialize the missing columns once. + if (colId == -1) { + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + orcVectorWrappers[i] = missingCol; + } else { + orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]); + } } } @@ -233,6 +240,7 @@ private boolean nextBatch() throws IOException { if (!copyToSpark) { for (int i = 0; i < requiredFields.length; i++) { + // It is possible that.. if (requestedColIds[i] != -1) { ((OrcColumnVector) orcVectorWrappers[i]).setBatchSize(batchSize); } @@ -248,7 +256,7 @@ private boolean nextBatch() throws IOException { StructField field = requiredFields[i]; WritableColumnVector toColumn = columnVectors[i]; - if (requestedColIds[i] >= 0) { + if (requestedColIds[i] != -1) { ColumnVector fromColumn = batch.cols[requestedColIds[i]]; if (fromColumn.isRepeating) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 53f44888ebaf..55421a2ed9db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,13 +29,13 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser} import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -190,35 +190,49 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { "read files of Hive data source directly.") } - val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) + val allPaths = (CaseInsensitiveMap(extraOptions.toMap).get("path") ++ paths).toSeq + val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf, allPaths) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance().asInstanceOf[DataSourceV2] - if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { + + val (needToFallBackFileDataSourceV2, fallBackFileFormat) = ds match { + case f: FileDataSourceV2 => + val disabledV2Readers = + sparkSession.sessionState.conf.disabledV2FileDataSourceReader.split(",") + (disabledV2Readers.contains(f.shortName), f.fallBackFileFormat.getCanonicalName) + case _ => (false, source) + } + val supportsRead = ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema] + if (supportsRead && !needToFallBackFileDataSourceV2) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) val pathsOption = { val objectMapper = new ObjectMapper() DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) } + Dataset.ofRows(sparkSession, DataSourceV2Relation.create( ds, extraOptions.toMap ++ sessionOptions + pathsOption, userSpecifiedSchema = userSpecifiedSchema)) } else { - loadV1Source(paths: _*) + // In the following cases, we fall back to loading with V1: + // 1. The data source implements v2, but has no v2 implementation for read path. + // 2. The v2 reader of the data source is configured as disabled. + loadV1Source(fallBackFileFormat, paths: _*) } } else { - loadV1Source(paths: _*) + loadV1Source(source, paths: _*) } } - private def loadV1Source(paths: String*) = { + private def loadV1Source(className: String, paths: String*) = { // Code path for data source v1. sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, paths = paths, userSpecifiedSchema = userSpecifiedSchema, - className = source, + className = className, options = extraOptions.toMap).resolveRelation()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 90bea2d676e2..488a14b01396 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -30,8 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, InsertIntoT import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.StructType @@ -241,39 +240,47 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance() - ds match { - case ws: WriteSupport => - val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = df.sparkSession.sessionState.conf)).asJava) - // Using a timestamp and a random UUID to distinguish different writing jobs. This is good - // enough as there won't be tons of writing jobs created at the same second. - val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) - .format(new Date()) + "-" + UUID.randomUUID() - val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options) - if (writer.isPresent) { - runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(writer.get(), df.logicalPlan) - } - } + val (needToFallBackFileDataSourceV2, fallBackFileFormat) = ds match { + case f: FileDataSourceV2 => + val disabledV2Readers = + df.sparkSession.sessionState.conf.disabledV2FileDataSourceWriter.split(",") + (disabledV2Readers.contains(f.shortName), f.fallBackFileFormat.getCanonicalName) + case _ => (false, source) + } - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving - // as though it's a V1 source. - case _ => saveToV1Source() + if (ds.isInstanceOf[WriteSupport] && !needToFallBackFileDataSourceV2) { + val options = new DataSourceOptions((extraOptions ++ + DataSourceV2Utils.extractSessionConfigs( + ds = ds.asInstanceOf[DataSourceV2], + conf = df.sparkSession.sessionState.conf)).asJava) + // Using a timestamp and a random UUID to distinguish different writing jobs. This is good + // enough as there won't be tons of writing jobs created at the same second. + val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) + .format(new Date()) + "-" + UUID.randomUUID() + val writer = ds.asInstanceOf[WriteSupport] + .createWriter(jobId, df.logicalPlan.schema, mode, options) + if (writer.isPresent) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(writer.get(), df.logicalPlan) + } + } + } else { + // In the following cases, we fall back to saving with V1: + // 1. The data source implements v2, but has no v2 implementation for write path. + // 2. The v2 writer of the data source is configured as disabled. + saveToV1Source(fallBackFileFormat) } } else { - saveToV1Source() + saveToV1Source(source) } } - private def saveToV1Source(): Unit = { + private def saveToV1Source(className: String): Unit = { // Code path for data source v1. runCommand(df.sparkSession, "save") { DataSource( sparkSession = df.sparkSession, - className = source, + className = className, partitionColumns = partitioningColumns.getOrElse(Nil), options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 08ff33afbba3..9a91cab095d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -368,8 +368,7 @@ case class FileSourceScanExec( val bucketed = selectedPartitions.flatMap { p => p.files.map { f => - val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen) - PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen, hosts) + PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) } }.groupBy { f => BucketingUtils @@ -396,107 +395,34 @@ case class FileSourceScanExec( readFile: (PartitionedFile) => Iterator[InternalRow], selectedPartitions: Seq[PartitionDirectory], fsRelation: HadoopFsRelation): RDD[InternalRow] = { - val defaultMaxSplitBytes = - fsRelation.sparkSession.sessionState.conf.filesMaxPartitionBytes + val maxSplitBytes = PartitionedFileUtil.maxSplitBytes(relation.sparkSession, selectedPartitions) val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes - val defaultParallelism = fsRelation.sparkSession.sparkContext.defaultParallelism - val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum - val bytesPerCore = totalBytes / defaultParallelism - - val maxSplitBytes = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore)) logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + s"open cost is considered as scanning $openCostInBytes bytes.") val splitFiles = selectedPartitions.flatMap { partition => partition.files.flatMap { file => - val blockLocations = getBlockLocations(file) - if (fsRelation.fileFormat.isSplitable( - fsRelation.sparkSession, fsRelation.options, file.getPath)) { - (0L until file.getLen by maxSplitBytes).map { offset => - val remaining = file.getLen - offset - val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining - val hosts = getBlockHosts(blockLocations, offset, size) - PartitionedFile( - partition.values, file.getPath.toUri.toString, offset, size, hosts) - } - } else { - val hosts = getBlockHosts(blockLocations, 0, file.getLen) - Seq(PartitionedFile( - partition.values, file.getPath.toUri.toString, 0, file.getLen, hosts)) - } + // getPath() is very expensive so we only want to call it once in this block: + val filePath = file.getPath + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, relation.options, filePath) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values + ) } }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse) - val partitions = new ArrayBuffer[FilePartition] - val currentFiles = new ArrayBuffer[PartitionedFile] - var currentSize = 0L - - /** Close the current partition and move to the next. */ - def closePartition(): Unit = { - if (currentFiles.nonEmpty) { - val newPartition = - FilePartition( - partitions.size, - currentFiles.toArray.toSeq) // Copy to a new Array. - partitions += newPartition - } - currentFiles.clear() - currentSize = 0 - } - - // Assign files to partitions using "Next Fit Decreasing" - splitFiles.foreach { file => - if (currentSize + file.length > maxSplitBytes) { - closePartition() - } - // Add the given file to the current partition. - currentSize += file.length + openCostInBytes - currentFiles += file - } - closePartition() + val partitions = + FilePartitionUtil.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes) new FileScanRDD(fsRelation.sparkSession, readFile, partitions) } - private def getBlockLocations(file: FileStatus): Array[BlockLocation] = file match { - case f: LocatedFileStatus => f.getBlockLocations - case f => Array.empty[BlockLocation] - } - - // Given locations of all blocks of a single file, `blockLocations`, and an `(offset, length)` - // pair that represents a segment of the same file, find out the block that contains the largest - // fraction the segment, and returns location hosts of that block. If no such block can be found, - // returns an empty array. - private def getBlockHosts( - blockLocations: Array[BlockLocation], offset: Long, length: Long): Array[String] = { - val candidates = blockLocations.map { - // The fragment starts from a position within this block - case b if b.getOffset <= offset && offset < b.getOffset + b.getLength => - b.getHosts -> (b.getOffset + b.getLength - offset).min(length) - - // The fragment ends at a position within this block - case b if offset <= b.getOffset && offset + length < b.getLength => - b.getHosts -> (offset + length - b.getOffset).min(length) - - // The fragment fully contains this block - case b if offset <= b.getOffset && b.getOffset + b.getLength <= offset + length => - b.getHosts -> b.getLength - - // The fragment doesn't intersect with this block - case b => - b.getHosts -> 0L - }.filter { case (hosts, size) => - size > 0L - } - - if (candidates.isEmpty) { - Array.empty[String] - } else { - val (hosts, _) = candidates.maxBy { case (_, size) => size } - hosts - } - } - override def doCanonicalize(): FileSourceScanExec = { FileSourceScanExec( relation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/PartitionedFileUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/PartitionedFileUtil.scala new file mode 100644 index 000000000000..06585271c165 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/PartitionedFileUtil.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.{PartitionedFile, _} + +object PartitionedFileUtil { + def splitFiles( + sparkSession: SparkSession, + file: FileStatus, + filePath: Path, + isSplitable: Boolean, + maxSplitBytes: Long, + partitionValues: InternalRow): Seq[PartitionedFile] = { + if (isSplitable) { + (0L until file.getLen by maxSplitBytes).map { offset => + val remaining = file.getLen - offset + val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining + val hosts = getBlockHosts(getBlockLocations(file), offset, size) + PartitionedFile(partitionValues, filePath.toUri.toString, offset, size, hosts) + } + } else { + Seq(getPartitionedFile(file, filePath, partitionValues)) + } + } + + def getPartitionedFile( + file: FileStatus, + filePath: Path, + partitionValues: InternalRow + ): PartitionedFile = { + val hosts = getBlockHosts(getBlockLocations(file), 0, file.getLen) + PartitionedFile(partitionValues, filePath.toUri.toString, 0, file.getLen, hosts) + } + + def maxSplitBytes( + sparkSession: SparkSession, + selectedPartitions: Seq[PartitionDirectory]): Long = { + val defaultMaxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes + val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes + val defaultParallelism = sparkSession.sparkContext.defaultParallelism + val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum + val bytesPerCore = totalBytes / defaultParallelism + + Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore)) + } + + private def getBlockLocations(file: FileStatus): Array[BlockLocation] = file match { + case f: LocatedFileStatus => f.getBlockLocations + case f => Array.empty[BlockLocation] + } + // Given locations of all blocks of a single file, `blockLocations`, and an `(offset, length)` + // pair that represents a segment of the same file, find out the block that contains the largest + // fraction the segment, and returns location hosts of that block. If no such block can be found, + // returns an empty array. + private def getBlockHosts( + blockLocations: Array[BlockLocation], + offset: Long, + length: Long): Array[String] = { + val candidates = blockLocations.map { + // The fragment starts from a position within this block + case b if b.getOffset <= offset && offset < b.getOffset + b.getLength => + b.getHosts -> (b.getOffset + b.getLength - offset).min(length) + + // The fragment ends at a position within this block + case b if offset <= b.getOffset && offset + length < b.getLength => + b.getHosts -> (offset + length - b.getOffset).min(length) + + // The fragment fully contains this block + case b if offset <= b.getOffset && b.getOffset + b.getLength <= offset + length => + b.getHosts -> b.getLength + + // The fragment doesn't intersect with this block + case b => + b.getHosts -> 0L + }.filter { case (hosts, size) => + size > 0L + } + + if (candidates.isEmpty) { + Array.empty[String] + } else { + val (hosts, _) = candidates.maxBy { case (_, size) => size } + hosts + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 44749190c79e..0da1c2630dfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -236,7 +237,7 @@ case class AlterTableAddColumnsCommand( // TextFileFormat only default to one column "value" // Hive type is already considered as hive serde table, so the logic will not // come in here. - case _: JsonFileFormat | _: CSVFileFormat | _: ParquetFileFormat => + case _: JsonFileFormat | _: CSVFileFormat | _: ParquetFileFormat | _: OrcDataSourceV2 => case s if s.getClass.getCanonicalName.endsWith("OrcFileFormat") => case s => throw new AnalysisException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f16d824201e7..61e1a8da27b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -18,11 +18,13 @@ package org.apache.spark.sql.execution.datasources import java.util.{Locale, ServiceConfigurationError, ServiceLoader} +import javax.activation.FileDataSource import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil @@ -39,6 +41,8 @@ import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} import org.apache.spark.sql.internal.SQLConf @@ -89,10 +93,19 @@ case class DataSource( case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String]) - lazy val providingClass: Class[_] = - DataSource.lookupDataSource(className, sparkSession.sessionState.conf) + lazy val providingClass: Class[_] = { + val cls = DataSource.lookupDataSource(className, sparkSession.sessionState.conf) + // Here `providingClass` is supposed to be V1 file format. Currently [[FileDataSourceV2]] + // doesn't support catalog, so creating tables with V2 file format still uses this code path. + // As a temporary hack to avoid failure, [[FileDataSourceV2]] is falled back to [[FileFormat]]. + cls.newInstance() match { + case f: FileDataSourceV2 => f.fallBackFileFormat + case _ => cls + } + } lazy val sourceInfo: SourceInfo = sourceSchema() private val caseInsensitiveOptions = CaseInsensitiveMap(options) + private val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis private val equality = sparkSession.sessionState.conf.resolver bucketSpec.map { bucket => @@ -426,7 +439,6 @@ case class DataSource( s"got: ${allPaths.mkString(", ")}") } - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive) val fileIndex = catalogTable.map(_.identifier).map { tableIdent => @@ -537,23 +549,8 @@ case class DataSource( checkFilesExist: Boolean): Seq[Path] = { val allPaths = caseInsensitiveOptions.get("path") ++ paths val hadoopConf = sparkSession.sessionState.newHadoopConf() - allPaths.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - - if (checkEmptyGlobPath && globPath.isEmpty) { - throw new AnalysisException(s"Path does not exist: $qualified") - } - - // Sufficient to check head of the globPath seq for non-glob scenario - // Don't need to check once again if files exist in streaming mode - if (checkFilesExist && !fs.exists(globPath.head)) { - throw new AnalysisException(s"Path does not exist: ${globPath.head}") - } - globPath - }.toSeq + DataSource.checkAndGlobPathIfNecessary(allPaths.toSeq, hadoopConf, + checkEmptyGlobPath, checkFilesExist) } } @@ -605,11 +602,14 @@ object DataSource extends Logging { "org.apache.spark.Logging") /** Given a provider name, look up the data source class definition. */ - def lookupDataSource(provider: String, conf: SQLConf): Class[_] = { + def lookupDataSource( + provider: String, + conf: SQLConf, + paths: Seq[String] = Seq.empty): Class[_] = { val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match { case name if name.equalsIgnoreCase("orc") && conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native" => - classOf[OrcFileFormat].getCanonicalName + classOf[OrcDataSourceV2].getCanonicalName case name if name.equalsIgnoreCase("orc") && conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "hive" => "org.apache.spark.sql.hive.orc.OrcFileFormat" @@ -690,6 +690,33 @@ object DataSource extends Logging { } } + /** + * Checks and returns files in all the paths. + */ + private[sql] def checkAndGlobPathIfNecessary( + paths: Seq[String], + hadoopConf: Configuration, + checkEmptyGlobPath: Boolean, + checkFilesExist: Boolean): Seq[Path] = { + paths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) + + if (checkEmptyGlobPath && globPath.isEmpty) { + throw new AnalysisException(s"Path does not exist: $qualified") + } + + // Sufficient to check head of the globPath seq for non-glob scenario + // Don't need to check once again if files exist in streaming mode + if (checkFilesExist && !fs.exists(globPath.head)) { + throw new AnalysisException(s"Path does not exist: ${globPath.head}") + } + globPath + } + } + /** * When creating a data source table, the `path` option has a special meaning: the table location. * This method extracts the `path` option and treat it as table location to build a diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 3f41612c0806..e4fdd12316a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.util.Locale import java.util.concurrent.Callable +import javax.activation.FileDataSource import org.apache.hadoop.fs.Path @@ -37,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileDataSourceV2} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -213,6 +215,26 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast } } +/** + * Replace the V2 data source of table in [[InsertIntoTable]] to V1 [[FileFormat]]. + * E.g, with temporary view `t` using [[FileDataSourceV2]], inserting into view `t` fails + * since there is no correspoding physical plan. + * This is a temporary hack for making current data source V2 work. + */ +class FallBackFileDataSourceToV1(sparkSession: SparkSession) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case i @InsertIntoTable(d: DataSourceV2Relation, _, _, _, _) + if d.source.isInstanceOf[FileDataSourceV2] => + val v1FileFormat = d.source.asInstanceOf[FileDataSourceV2].fallBackFileFormat + val v1 = DataSource.apply( + sparkSession = sparkSession, + paths = d.v2Options.paths(), + userSpecifiedSchema = d.userSpecifiedSchema, + className = v1FileFormat.getCanonicalName, + options = d.options - "path").resolveRelation() + i.copy(table = LogicalRelation(v1)) + } +} /** * Replaces [[UnresolvedCatalogRelation]] with concrete relation logical plans. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartitionUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartitionUtil.scala new file mode 100644 index 000000000000..e9210c195ed1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartitionUtil.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources + +import java.io.{FileNotFoundException, IOException} + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.parquet.io.ParquetDecodingException + +import org.apache.spark.TaskContext +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.InputFileBlockHolder +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.NextIterator + +object FilePartitionUtil extends Logging { + + def getFilePartitions( + sparkSession: SparkSession, + partitionedFiles: Seq[PartitionedFile], + maxSplitBytes: Long): Seq[FilePartition] = { + val partitions = new ArrayBuffer[FilePartition] + val currentFiles = new ArrayBuffer[PartitionedFile] + var currentSize = 0L + + /** Close the current partition and move to the next. */ + def closePartition(): Unit = { + if (currentFiles.nonEmpty) { + val newPartition = + FilePartition( + partitions.size, + currentFiles.toArray.toSeq) // Copy to a new Array. + partitions += newPartition + } + currentFiles.clear() + currentSize = 0 + } + + val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes + // Assign files to partitions using "Next Fit Decreasing" + partitionedFiles.foreach { file => + if (currentSize + file.length > maxSplitBytes) { + closePartition() + } + // Add the given file to the current partition. + currentSize += file.length + openCostInBytes + currentFiles += file + } + closePartition() + partitions + } + + def compute( + split: FilePartition, + context: TaskContext, + readFunction: (PartitionedFile) => Iterator[InternalRow], + ignoreCorruptFiles: Boolean = false, + ignoreMissingFiles: Boolean = false): Iterator[InternalRow] = { + val iterator = new Iterator[Object] with AutoCloseable { + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead + + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // apply readFunction, because it might read some bytes. + private val getBytesReadCallback = + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + + // We get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). + private def updateBytesRead(): Unit = { + inputMetrics.setBytesRead(existingBytesRead + getBytesReadCallback()) + } + + // If we can't get the bytes read from the FS stats, fall back to the file size, + // which may be inaccurate. + private def updateBytesReadWithFileSize(): Unit = { + if (currentFile != null) { + inputMetrics.incBytesRead(currentFile.length) + } + } + + private[this] val files = split.asInstanceOf[FilePartition].files.toIterator + private[this] var currentFile: PartitionedFile = null + private[this] var currentIterator: Iterator[Object] = null + + def hasNext: Boolean = { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. + context.killTaskIfInterrupted() + (currentIterator != null && currentIterator.hasNext) || nextIterator() + } + def next(): Object = { + val nextElement = currentIterator.next() + // TODO: we should have a better separation of row based and batch based scan, so that we + // don't need to run this `if` for every record. + if (nextElement.isInstanceOf[ColumnarBatch]) { + inputMetrics.incRecordsRead(nextElement.asInstanceOf[ColumnarBatch].numRows()) + } else { + inputMetrics.incRecordsRead(1) + } + if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + updateBytesRead() + } + nextElement + } + + private def readCurrentFile(): Iterator[InternalRow] = { + try { + readFunction(currentFile) + } catch { + case e: FileNotFoundException => + throw new FileNotFoundException( + e.getMessage + "\n" + + "It is possible the underlying files have been updated. " + + "You can explicitly invalidate the cache in Spark by " + + "running 'REFRESH TABLE tableName' command in SQL or " + + "by recreating the Dataset/DataFrame involved.") + } + } + + /** Advances to the next file. Returns true if a new non-empty iterator is available. */ + private def nextIterator(): Boolean = { + updateBytesReadWithFileSize() + if (files.hasNext) { + currentFile = files.next() + logInfo(s"Reading File $currentFile") + // Sets InputFileBlockHolder for the file block's information + InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) + + if (ignoreMissingFiles || ignoreCorruptFiles) { + currentIterator = new NextIterator[Object] { + // The readFunction may read some bytes before consuming the iterator, e.g., + // vectorized Parquet reader. Here we use lazy val to delay the creation of + // iterator so that we will throw exception in `getNext`. + private lazy val internalIter = readCurrentFile() + + override def getNext(): AnyRef = { + try { + if (internalIter.hasNext) { + internalIter.next() + } else { + finished = true + null + } + } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: $currentFile", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e + case e @ (_: RuntimeException | _: IOException) if ignoreCorruptFiles => + logWarning( + s"Skipped the rest of the content in the corrupted file: $currentFile", e) + finished = true + null + } + } + + override def close(): Unit = {} + } + } else { + currentIterator = readCurrentFile() + } + + try { + hasNext + } catch { + case e: SchemaColumnConvertNotSupportedException => + val message = "Parquet column cannot be converted in " + + s"file ${currentFile.filePath}. Column: ${e.getColumn}, " + + s"Expected: ${e.getLogicalType}, Found: ${e.getPhysicalType}" + throw new QueryExecutionException(message, e) + case e: ParquetDecodingException => + if (e.getMessage.contains("Can not read value at")) { + val message = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the " + + "corresponding files. Details: " + throw new QueryExecutionException(message, e) + } + throw e + } + } else { + currentFile = null + InputFileBlockHolder.unset() + false + } + } + + override def close(): Unit = { + updateBytesRead() + updateBytesReadWithFileSize() + InputFileBlockHolder.unset() + } + } + + // Register an on-task-completion callback to close the input stream. + context.addTaskCompletionListener(_ => iterator.close()) + + iterator.asInstanceOf[Iterator[InternalRow]] // This is an erasure hack. + } + + def getPreferredLocations(split: FilePartition): Array[String] = { + val files = split.files + + // Computes total number of bytes can be retrieved from each host. + val hostToNumBytes = mutable.HashMap.empty[String, Long] + files.foreach { file => + file.locations.filter(_ != "localhost").foreach { host => + hostToNumBytes(host) = hostToNumBytes.getOrElse(host, 0L) + file.length + } + } + + // Takes the first 3 hosts with the most data to be retrieved + hostToNumBytes.toSeq.sortBy { + case (host, numBytes) => numBytes + }.reverse.take(3).map { + case (host, numBytes) => host + }.toArray + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 28c36b6020d3..ddc3ec7c23d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -17,20 +17,10 @@ package org.apache.spark.sql.execution.datasources -import java.io.{FileNotFoundException, IOException} - -import scala.collection.mutable - -import org.apache.parquet.io.ParquetDecodingException - -import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.rdd.{InputFileBlockHolder, RDD} +import org.apache.spark.{Partition => RDDPartition, TaskContext} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.QueryExecutionException -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.NextIterator /** * A part (i.e. "block") of a single file that should be read, along with partition column values @@ -72,171 +62,13 @@ class FileScanRDD( private val ignoreMissingFiles = sparkSession.sessionState.conf.ignoreMissingFiles override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { - val iterator = new Iterator[Object] with AutoCloseable { - private val inputMetrics = context.taskMetrics().inputMetrics - private val existingBytesRead = inputMetrics.bytesRead - - // Find a function that will return the FileSystem bytes read by this thread. Do this before - // apply readFunction, because it might read some bytes. - private val getBytesReadCallback = - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - - // We get our input bytes from thread-local Hadoop FileSystem statistics. - // If we do a coalesce, however, we are likely to compute multiple partitions in the same - // task and in the same thread, in which case we need to avoid override values written by - // previous partitions (SPARK-13071). - private def updateBytesRead(): Unit = { - inputMetrics.setBytesRead(existingBytesRead + getBytesReadCallback()) - } - - // If we can't get the bytes read from the FS stats, fall back to the file size, - // which may be inaccurate. - private def updateBytesReadWithFileSize(): Unit = { - if (currentFile != null) { - inputMetrics.incBytesRead(currentFile.length) - } - } - - private[this] val files = split.asInstanceOf[FilePartition].files.toIterator - private[this] var currentFile: PartitionedFile = null - private[this] var currentIterator: Iterator[Object] = null - - def hasNext: Boolean = { - // Kill the task in case it has been marked as killed. This logic is from - // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order - // to avoid performance overhead. - context.killTaskIfInterrupted() - (currentIterator != null && currentIterator.hasNext) || nextIterator() - } - def next(): Object = { - val nextElement = currentIterator.next() - // TODO: we should have a better separation of row based and batch based scan, so that we - // don't need to run this `if` for every record. - if (nextElement.isInstanceOf[ColumnarBatch]) { - inputMetrics.incRecordsRead(nextElement.asInstanceOf[ColumnarBatch].numRows()) - } else { - inputMetrics.incRecordsRead(1) - } - if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { - updateBytesRead() - } - nextElement - } - - private def readCurrentFile(): Iterator[InternalRow] = { - try { - readFunction(currentFile) - } catch { - case e: FileNotFoundException => - throw new FileNotFoundException( - e.getMessage + "\n" + - "It is possible the underlying files have been updated. " + - "You can explicitly invalidate the cache in Spark by " + - "running 'REFRESH TABLE tableName' command in SQL or " + - "by recreating the Dataset/DataFrame involved.") - } - } - - /** Advances to the next file. Returns true if a new non-empty iterator is available. */ - private def nextIterator(): Boolean = { - updateBytesReadWithFileSize() - if (files.hasNext) { - currentFile = files.next() - logInfo(s"Reading File $currentFile") - // Sets InputFileBlockHolder for the file block's information - InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) - - if (ignoreMissingFiles || ignoreCorruptFiles) { - currentIterator = new NextIterator[Object] { - // The readFunction may read some bytes before consuming the iterator, e.g., - // vectorized Parquet reader. Here we use lazy val to delay the creation of - // iterator so that we will throw exception in `getNext`. - private lazy val internalIter = readCurrentFile() - - override def getNext(): AnyRef = { - try { - if (internalIter.hasNext) { - internalIter.next() - } else { - finished = true - null - } - } catch { - case e: FileNotFoundException if ignoreMissingFiles => - logWarning(s"Skipped missing file: $currentFile", e) - finished = true - null - // Throw FileNotFoundException even if `ignoreCorruptFiles` is true - case e: FileNotFoundException if !ignoreMissingFiles => throw e - case e @ (_: RuntimeException | _: IOException) if ignoreCorruptFiles => - logWarning( - s"Skipped the rest of the content in the corrupted file: $currentFile", e) - finished = true - null - } - } - - override def close(): Unit = {} - } - } else { - currentIterator = readCurrentFile() - } - - try { - hasNext - } catch { - case e: SchemaColumnConvertNotSupportedException => - val message = "Parquet column cannot be converted in " + - s"file ${currentFile.filePath}. Column: ${e.getColumn}, " + - s"Expected: ${e.getLogicalType}, Found: ${e.getPhysicalType}" - throw new QueryExecutionException(message, e) - case e: ParquetDecodingException => - if (e.getMessage.contains("Can not read value at")) { - val message = "Encounter error while reading parquet files. " + - "One possible cause: Parquet column cannot be converted in the " + - "corresponding files. Details: " - throw new QueryExecutionException(message, e) - } - throw e - } - } else { - currentFile = null - InputFileBlockHolder.unset() - false - } - } - - override def close(): Unit = { - updateBytesRead() - updateBytesReadWithFileSize() - InputFileBlockHolder.unset() - } - } - - // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(_ => iterator.close()) - - iterator.asInstanceOf[Iterator[InternalRow]] // This is an erasure hack. + FilePartitionUtil.compute(split.asInstanceOf[FilePartition], context, + readFunction, ignoreCorruptFiles, ignoreMissingFiles) } override protected def getPartitions: Array[RDDPartition] = filePartitions.toArray override protected def getPreferredLocations(split: RDDPartition): Seq[String] = { - val files = split.asInstanceOf[FilePartition].files - - // Computes total number of bytes can be retrieved from each host. - val hostToNumBytes = mutable.HashMap.empty[String, Long] - files.foreach { file => - file.locations.filter(_ != "localhost").foreach { host => - hostToNumBytes(host) = hostToNumBytes.getOrElse(host, 0L) + file.length - } - } - - // Takes the first 3 hosts with the most data to be retrieved - hostToNumBytes.toSeq.sortBy { - case (host, numBytes) => numBytes - }.reverse.take(3).map { - case (host, numBytes) => host - } + FilePartitionUtil.getPreferredLocations(split.asInstanceOf[FilePartition]) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index b2f73b7f8d1f..d278802e6c9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -52,28 +52,12 @@ case class HadoopFsRelation( override def sqlContext: SQLContext = sparkSession.sqlContext - private def getColName(f: StructField): String = { - if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { - f.name - } else { - f.name.toLowerCase(Locale.ROOT) - } - } - - val overlappedPartCols = mutable.Map.empty[String, StructField] - partitionSchema.foreach { partitionField => - if (dataSchema.exists(getColName(_) == getColName(partitionField))) { - overlappedPartCols += getColName(partitionField) -> partitionField - } - } - // When data and partition schemas have overlapping columns, the output // schema respects the order of the data schema for the overlapping columns, and it // respects the data types of the partition schema. - val schema: StructType = { - StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ - partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) - } + val (schema: StructType, overlappedPartCols: Map[String, StructField]) = + PartitioningUtils.mergeDataAndPartitionSchema(dataSchema, + partitionSchema, sparkSession.sessionState.conf.caseSensitiveAnalysis) def partitionSchemaOption: Option[StructType] = if (partitionSchema.isEmpty) None else Some(partitionSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 1edf27619ad7..862ee64de9b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -21,6 +21,7 @@ import java.lang.{Double => JDouble, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import java.util.{Locale, TimeZone} +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -493,6 +494,61 @@ object PartitioningUtils { }).asNullable } + def mergeDataAndPartitionSchema( + dataSchema: StructType, + partitionSchema: StructType, + caseSensitive: Boolean): (StructType, Map[String, StructField]) = { + val equality = columnNameEquality(caseSensitive) + val overlappedPartCols = mutable.Map.empty[String, StructField] + partitionSchema.foreach { partitionField => + val partitionFieldName = getColName(partitionField, caseSensitive) + if (dataSchema.exists(getColName(_, caseSensitive) == partitionFieldName)) { + overlappedPartCols += partitionFieldName -> partitionField + } + } + + // When data and partition schemas have overlapping columns, the output + // schema respects the order of the data schema for the overlapping columns, and it + // respects the data types of the partition schema. + val fullSchema = + StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f, caseSensitive), f)) ++ + partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f, caseSensitive)))) + (fullSchema, overlappedPartCols.toMap) + } + + def requestedPartitionColumnIds( + partitionSchema: StructType, + requiredSchema: StructType, + caseSensitive: Boolean): Array[Int] = { + val columnNameMap = + partitionSchema.fields.map(getColName(_, caseSensitive)).zipWithIndex.toMap + requiredSchema.fields.map { field => + columnNameMap.getOrElse(getColName(field, caseSensitive), -1) + } + } + + /** + * Returns a new StructType that is a copy of the original StructType, removing any items that + * also appear in other StructType. The order is preserved from the original StructType. + */ + def subtractSchema(original: StructType, other: StructType, isCaseSensitive: Boolean) + : StructType = { + val otherNameSet = other.fields.map(getColName(_, isCaseSensitive)).toSet + val fields = original.fields.filterNot { field => + otherNameSet.contains(getColName(field, isCaseSensitive)) + } + + StructType(fields) + } + + private def getColName(f: StructField, caseSensitive: Boolean): String = { + if (caseSensitive) { + f.name + } else { + f.name.toLowerCase(Locale.ROOT) + } + } + private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = { if (caseSensitive) { org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 1de2ca2914c4..835dd0315d90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -173,8 +173,9 @@ class OrcFileFormat if (requestedColIdsOrEmptyFile.isEmpty) { Iterator.empty } else { - val requestedColIds = requestedColIdsOrEmptyFile.get - assert(requestedColIds.length == requiredSchema.length, + val requestedColIds = + requestedColIdsOrEmptyFile.get ++ Array.fill(partitionSchema.length)(-1) + assert(requestedColIds.length == resultSchema.length, "[BUG] requested column IDs do not match required schema") val taskConf = new Configuration(conf) taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, @@ -193,13 +194,14 @@ class OrcFileFormat // after opening a file. val iter = new RecordReaderIterator(batchReader) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) - + val requestedPartitionColIds = + Array.fill(requiredSchema.length)(-1) ++ Range(0, partitionSchema.length) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, + resultSchema.fields, requestedColIds, - requiredSchema.fields, - partitionSchema, + requestedPartitionColIds, file.partitionValues) iter.asInstanceOf[Iterator[InternalRow]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 4f44ae4fa1d7..7cac51828f8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -54,7 +54,7 @@ import org.apache.spark.sql.types._ * builder methods mentioned above can only be found in test code, where all tested filters are * known to be convertible. */ -private[orc] object OrcFilters { +object OrcFilters { /** * Create ORC filter as a SearchArgument instance. @@ -64,19 +64,23 @@ private[orc] object OrcFilters { // First, tries to convert each filter individually to see whether it's convertible, and then // collect all convertible ones to build the final `SearchArgument`. - val convertibleFilters = for { - filter <- filters - _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder()) - } yield filter + val convertible = convertibleFilters(schema, filters) for { // Combines all convertible filters using `And` to produce a single conjunction - conjunction <- convertibleFilters.reduceOption(org.apache.spark.sql.sources.And) + conjunction <- convertible.reduceOption(org.apache.spark.sql.sources.And) // Then tries to build a single ORC `SearchArgument` for the conjunction predicate builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder()) } yield builder.build() } + def convertibleFilters(schema: StructType, filters: Seq[Filter]): Seq[Filter] = { + val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + for { + filter <- filters + _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder()) + } yield filter + } /** * Return true if this is a searchable type in ORC. * Both CharType and VarcharType are cleaned at AstBuilder. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 90fb5a14c9fc..83ef7e534517 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -52,7 +52,7 @@ case class DataSourceV2Relation( AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) } - private lazy val v2Options: DataSourceOptions = makeV2Options(options) + private[sql] lazy val v2Options: DataSourceOptions = makeV2Options(options) // postScanFilters: filters that need to be evaluated after the scan. // pushedFilters: filters that will be pushed down and evaluated in the underlying data sources. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/EmptyInputPartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/EmptyInputPartitionReader.scala new file mode 100644 index 000000000000..bf570a4ba07a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/EmptyInputPartitionReader.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.io.IOException + +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader + +/** + * A [[InputPartitionReader]] with empty output. + */ +class EmptyInputPartitionReader[T] extends InputPartitionReader[T] { + override def next(): Boolean = false + + override def get(): T = + throw new IOException("No records should be returned from EmptyDataReader") + + override def close(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala new file mode 100644 index 000000000000..82218d5eeec2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.ReadSupport +import org.apache.spark.sql.sources.v2.WriteSupport + +/** + * A base interface for data source v2 implementations of the built-in file-based data sources. + */ +trait FileDataSourceV2 extends DataSourceV2 with DataSourceRegister { + /** + * Returns a V1 [[FileFormat]] class of the same file data source. + * This is a solution for the following cases: + * 1. File datasource V2 might be implemented partially during migration. + * E.g. if [[ReadSupport]] is implemented while [[WriteSupport]] is not, + * write path should fall back to V1 implementation. + * 2. File datasource V2 implementations cause regression. + * 3. Catalog support is required, which is still under development for data source V2. + */ + def fallBackFileFormat: Class[_ <: FileFormat] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileInputPartition.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileInputPartition.scala new file mode 100644 index 000000000000..9399c98e2710 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileInputPartition.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.io.{FileNotFoundException, IOException} + +import org.apache.spark.TaskContext +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.InputFileBlockHolder +import org.apache.spark.sql.execution.datasources.{FilePartition, FilePartitionUtil, PartitionedFile} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.vectorized.ColumnarBatch + +case class FileInputPartition[T]( + file: FilePartition, + readFunction: (PartitionedFile) => InputPartitionReader[T], + ignoreCorruptFiles: Boolean = false, + ignoreMissingFiles: Boolean = false) + extends InputPartition[T] { + override def createPartitionReader(): InputPartitionReader[T] = { + val taskContext = TaskContext.get() + val iter = file.files.iterator.map(f => PartitionedFileReader(f, readFunction(f))) + FileInputPartitionReader(taskContext, iter, ignoreCorruptFiles, ignoreMissingFiles) + } + + override def preferredLocations(): Array[String] = { + FilePartitionUtil.getPreferredLocations(file) + } +} + +case class PartitionedFileReader[T]( + file: PartitionedFile, + reader: InputPartitionReader[T]) extends InputPartitionReader[T] { + override def next(): Boolean = reader.next() + + override def get(): T = reader.get() + + override def close(): Unit = reader.close() +} + +case class FileInputPartitionReader[T]( + context: TaskContext, + readers: Iterator[PartitionedFileReader[T]], + ignoreCorruptFiles: Boolean, + ignoreMissingFiles: Boolean) extends InputPartitionReader[T] with Logging { + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead + + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // apply readFunction, because it might read some bytes. + private val getBytesReadCallback = + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + + // We get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). + private def updateBytesRead(): Unit = { + inputMetrics.setBytesRead(existingBytesRead + getBytesReadCallback()) + } + + // If we can't get the bytes read from the FS stats, fall back to the file size, + // which may be inaccurate. + private def updateBytesReadWithFileSize(): Unit = { + if (currentFile != null) { + inputMetrics.incBytesRead(currentFile.file.length) + } + } + + private[this] var currentFile: PartitionedFileReader[T] = null + + private def hasNext(): Boolean = { + if (currentFile == null) { + if (readers.hasNext) { + if (ignoreMissingFiles || ignoreCorruptFiles) { + try { + currentFile = readers.next() + } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: $currentFile", e) + currentFile = null + return false + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e + case e @ (_: RuntimeException | _: IOException) if ignoreCorruptFiles => + logWarning( + s"Skipped the rest of the content in the corrupted file: $currentFile", e) + currentFile = null + return false + } + } else { + currentFile = readers.next() + } + } else { + return false + } + } + if (currentFile.next()) { + return true + } else { + close() + currentFile = null + } + hasNext() + } + + override def next(): Boolean = { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. + context.killTaskIfInterrupted() + + hasNext() + } + + override def get(): T = { + val nextElement = currentFile.get() + // TODO: we should have a better separation of row based and batch based scan, so that we + // don't need to run this `if` for every record. + if (nextElement.isInstanceOf[ColumnarBatch]) { + inputMetrics.incRecordsRead(nextElement.asInstanceOf[ColumnarBatch].numRows()) + } else { + inputMetrics.incRecordsRead(1) + } + if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + updateBytesRead() + } + nextElement + } + + override def close(): Unit = { + updateBytesRead() + updateBytesReadWithFileSize() + InputFileBlockHolder.unset() + if (currentFile != null) { + currentFile.close() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileSourceReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileSourceReader.scala new file mode 100644 index 000000000000..c4ab843a22b8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileSourceReader.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.{List => JList} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeRow} +import org.apache.spark.sql.execution.PartitionedFileUtil +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +abstract class FileSourceReader(options: DataSourceOptions, userSpecifiedSchema: Option[StructType]) + extends DataSourceReader + with SupportsScanUnsafeRow + with SupportsPushDownRequiredColumns + with SupportsPushDownCatalystFilters { + /** + * When possible, this method should return the schema of the given `files`. When the format + * does not support inference, or no valid files are given should return None. In these cases + * Spark will require that user specify the schema manually. + */ + def inferSchema(files: Seq[FileStatus]): Option[StructType] + + /** + * Returns whether a file with `path` could be split or not. + */ + def isSplitable(path: Path): Boolean = { + false + } + + /** + * Returns a function that can be used to read a single file in as an [[InputPartitionReader]] of + * [[UnsafeRow]]. + */ + def unsafeInputPartitionReader: PartitionedFile => InputPartitionReader[UnsafeRow] + + protected val sparkSession = SparkSession.getActiveSession + .getOrElse(SparkSession.getDefaultSession.get) + protected val hadoopConf = + sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap) + protected val sqlConf = sparkSession.sessionState.conf + + protected val isCaseSensitive = sqlConf.caseSensitiveAnalysis + protected val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles + protected val ignoreMissingFiles = sqlConf.ignoreMissingFiles + private lazy val rootPathsSpecified = { + val filePaths = options.paths() + if (filePaths.isEmpty) { + throw new AnalysisException("Reading data source requires a" + + " path (e.g. data backed by a local or distributed file system).") + } + DataSource.checkAndGlobPathIfNecessary(filePaths, hadoopConf, + checkEmptyGlobPath = false, checkFilesExist = false) + } + + protected lazy val fileIndex = { + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + new InMemoryFileIndex(sparkSession, rootPathsSpecified, + options.asMap().asScala.toMap, userSpecifiedSchema, fileStatusCache) + } + + protected lazy val partitionSchema = fileIndex.partitionSchema + + protected lazy val dataSchema = userSpecifiedSchema.orElse { + inferSchema(fileIndex.allFiles()) + }.getOrElse { + throw new AnalysisException( + s"Unable to infer schema for $rootPathsSpecified. It must be specified manually.") + } + protected val (fullSchema, _) = + PartitioningUtils.mergeDataAndPartitionSchema(dataSchema, partitionSchema, isCaseSensitive) + protected var requiredSchema = fullSchema + protected var partitionFilters: Array[Expression] = Array.empty + protected var pushedFiltersArray: Array[Expression] = Array.empty + + protected def partitions: Seq[FilePartition] = { + val selectedPartitions = fileIndex.listFiles(partitionFilters, Seq.empty) + val maxSplitBytes = PartitionedFileUtil.maxSplitBytes(sparkSession, selectedPartitions) + val splitFiles = selectedPartitions.flatMap { partition => + partition.files.flatMap { file => + val filePath = file.getPath + PartitionedFileUtil.splitFiles( + sparkSession = sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable(filePath), + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values + ) + }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse) + } + FilePartitionUtil.getFilePartitions(sparkSession, splitFiles, maxSplitBytes) + } + + override def readSchema(): StructType = { + requiredSchema + } + + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema + } + + override def pushCatalystFilters(filters: Array[Expression]): Array[Expression] = Array.empty + + override def pushedCatalystFilters(): Array[Expression] = { + pushedFiltersArray + } + + override def planUnsafeInputPartitions: JList[InputPartition[UnsafeRow]] = { + partitions.map { filePartition => + new FileInputPartition[UnsafeRow](filePartition, unsafeInputPartitionReader, + ignoreCorruptFiles, ignoreMissingFiles) + .asInstanceOf[InputPartition[UnsafeRow]] + }.asJava + } +} + +abstract class ColumnarBatchFileSourceReader( + options: DataSourceOptions, + userSpecifiedSchema: Option[StructType]) + extends FileSourceReader(options: DataSourceOptions, userSpecifiedSchema: Option[StructType]) + with SupportsScanColumnarBatch { + /** + * Returns a function that can be used to read a single file in as an [[InputPartitionReader]] of + * [[ColumnarBatch]]. + */ + def columnarBatchInputPartitionReader: PartitionedFile => InputPartitionReader[ColumnarBatch] + + override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = { + partitions.map { filePartition => + new FileInputPartition[ColumnarBatch](filePartition, columnarBatchInputPartitionReader, + ignoreCorruptFiles, ignoreMissingFiles) + .asInstanceOf[InputPartition[ColumnarBatch]] + }.asJava + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionRecordReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionRecordReader.scala new file mode 100644 index 000000000000..a0404a21f845 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionRecordReader.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.hadoop.mapreduce.RecordReader + +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader + +class PartitionRecordReader[T]( + private[this] var rowReader: RecordReader[_, T]) extends InputPartitionReader[T] { + override def next(): Boolean = rowReader.nextKeyValue() + + override def get(): T = rowReader.getCurrentValue + + override def close(): Unit = rowReader.close() +} + +class PartitionRecordDReaderWithProject[X, T]( + private[this] var rowReader: RecordReader[_, X], + project: X => T) extends InputPartitionReader[T] { + override def next(): Boolean = rowReader.nextKeyValue() + + override def get(): T = project(rowReader.getCurrentValue) + + override def close(): Unit = rowReader.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala new file mode 100644 index 000000000000..08e1e4d30e3a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.orc + +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType} +import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.orc.{OrcConf, OrcFile} +import org.apache.orc.mapred.OrcStruct +import org.apache.orc.mapreduce.OrcInputFormat + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.orc._ +import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.{DataSourceOptions, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.types.{AtomicType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration + +class OrcDataSourceV2 extends FileDataSourceV2 with ReadSupport with ReadSupportWithSchema { + override def createReader(options: DataSourceOptions): DataSourceReader = { + new OrcDataSourceReader(options, None) + } + + override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = { + new OrcDataSourceReader(options, Some(schema)) + } + + override def fallBackFileFormat: Class[_ <: FileFormat] = classOf[OrcFileFormat] + + override def shortName(): String = "orc" +} + +case class OrcDataSourceReader(options: DataSourceOptions, userSpecifiedSchema: Option[StructType]) + extends ColumnarBatchFileSourceReader(options: DataSourceOptions, + userSpecifiedSchema: Option[StructType]) { + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + OrcUtils.readSchema(sparkSession, files) + } + + override def pushCatalystFilters(filters: Array[Expression]): Array[Expression] = { + val partitionColumnNames = partitionSchema.toAttributes.map(_.name).toSet + val (partitionKeyFilters, otherFilters) = filters.partition { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + this.partitionFilters = partitionKeyFilters + pushedFiltersArray = partitionKeyFilters + if (sqlConf.orcFilterPushDown) { + val dataFilters = otherFilters.map { f => + (DataSourceStrategy.translateFilter(f), f) + }.collect { case (optionalFilter, catalystFilter) if optionalFilter.isDefined => + (optionalFilter.get, catalystFilter) + }.toMap + val pushedDataFilters = + OrcFilters.convertibleFilters(fullSchema, dataFilters.keys.toSeq).map(dataFilters).toArray + pushedFiltersArray ++= pushedDataFilters + OrcFilters.createFilter(fullSchema, dataFilters.keys.toSeq).foreach { f => + OrcInputFormat.setSearchArgument(hadoopConf, f, fullSchema.fieldNames) + } + } + otherFilters + } + + override def enableBatchRead(): Boolean = { + val schema = readSchema() + sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled && + schema.length <= sqlConf.wholeStageMaxNumFields && + schema.forall(_.dataType.isInstanceOf[AtomicType]) + } + + override def isSplitable(path: Path): Boolean = true + + override def columnarBatchInputPartitionReader: + (PartitionedFile) => InputPartitionReader[ColumnarBatch] = { + val capacity = sqlConf.orcVectorizedReaderBatchSize + val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled + val copyToSpark = sqlConf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK) + val isCaseSensitive = this.isCaseSensitive + val dataSchema = this.dataSchema + val readSchema = this.readSchema() + val partitionSchema = this.partitionSchema + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + (file: PartitionedFile) => { + val conf = broadcastedConf.value.value + + val filePath = new Path(new URI(file.filePath)) + + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(filePath, readerOptions) + + val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds( + isCaseSensitive, dataSchema, readSchema, reader, conf) + + if (requestedColIdsOrEmptyFile.isEmpty) { + new EmptyInputPartitionReader + } else { + val requestedColIds = requestedColIdsOrEmptyFile.get + assert(requestedColIds.length == readSchema.length, + "[BUG] requested column IDs do not match required schema") + val taskConf = new Configuration(conf) + taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, + requestedColIds.filter(_ != -1).sorted.mkString(",")) + + val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) + + val taskContext = Option(TaskContext.get()) + val batchReader = new OrcColumnarBatchReader( + enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity) + batchReader.initialize(fileSplit, taskAttemptContext) + val partitionColIds = PartitioningUtils.requestedPartitionColumnIds( + partitionSchema, readSchema, isCaseSensitive) + + batchReader.initBatch( + reader.getSchema, + readSchema.fields, + requestedColIds, + partitionColIds, + file.partitionValues) + new PartitionRecordReader(batchReader) + } + } + } + + override def unsafeInputPartitionReader: (PartitionedFile) => InputPartitionReader[UnsafeRow] = { + val isCaseSensitive = this.isCaseSensitive + val dataSchema = this.dataSchema + val readSchema = this.readSchema() + val partitionSchema = this.partitionSchema + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + (file: PartitionedFile) => { + val conf = broadcastedConf.value.value + + val filePath = new Path(new URI(file.filePath)) + + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(filePath, readerOptions) + + val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds( + isCaseSensitive, dataSchema, readSchema, reader, conf) + + if (requestedColIdsOrEmptyFile.isEmpty) { + new EmptyInputPartitionReader[UnsafeRow] + } else { + val requestedColIds = requestedColIdsOrEmptyFile.get + assert(requestedColIds.length == readSchema.length, + "[BUG] requested column IDs do not match required schema") + val taskConf = new Configuration(conf) + taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, + requestedColIds.filter(_ != -1).sorted.mkString(",")) + + val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) + + val requiredDataSchema = + PartitioningUtils.subtractSchema(readSchema, partitionSchema, isCaseSensitive) + val orcRecordReader = new OrcInputFormat[OrcStruct] + .createRecordReader(fileSplit, taskAttemptContext) + + val fullSchema = requiredDataSchema.toAttributes ++ partitionSchema.toAttributes + val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) + val deserializer = new OrcDeserializer(dataSchema, requiredDataSchema, requestedColIds) + + val projection = if (partitionSchema.length == 0) { + (value: OrcStruct) => unsafeProjection(deserializer.deserialize(value)) + } else { + val joinedRow = new JoinedRow() + (value: OrcStruct) => + unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues)) + } + new PartitionRecordDReaderWithProject(orcRecordReader, projection) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 3a0db7e16c23..81beb9b1c304 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -158,6 +158,7 @@ abstract class BaseSessionStateBuilder( override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: + new FallBackFileDataSourceToV1(session) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index 8680b86517b1..d4dd09c18eb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -29,6 +29,10 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 +import org.apache.spark.sql.sources.v2.{DataSourceOptions, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -50,18 +54,30 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - var maybeRelation: Option[HadoopFsRelation] = None + var maybeDataReader: Option[DataSourceReader] = None val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => - maybeRelation = Some(orcRelation) + case PhysicalOperation(_, filters, + DataSourceV2Relation(orcRelation: OrcDataSourceV2, options, _, _, userSpecifiedSchema)) => + val dataSourceOptions = new DataSourceOptions(options.asJava) + val dataReader = if (userSpecifiedSchema.isDefined) { + orcRelation.asInstanceOf[ReadSupportWithSchema] + .createReader(userSpecifiedSchema.get, dataSourceOptions) + } else { + orcRelation.asInstanceOf[ReadSupport].createReader(dataSourceOptions) + } + maybeDataReader = Some(dataReader) filters }.flatten.reduceLeftOption(_ && _) assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") - val (_, selectedFilters, _) = - DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) - assert(selectedFilters.nonEmpty, "No filter is pushed down") + val pushDownCatalystFiltersReader = + maybeDataReader.get.asInstanceOf[SupportsPushDownCatalystFilters] + pushDownCatalystFiltersReader.pushCatalystFilters(Array(maybeAnalyzedPredicate.get)) + val selectedCatalystFilters = + pushDownCatalystFiltersReader.pushedCatalystFilters() + assert(selectedCatalystFilters.nonEmpty, "No filter is pushed down") + val selectedFilters = selectedCatalystFilters.flatMap(DataSourceStrategy.translateFilter) val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") checker(maybeFilter.get) @@ -94,20 +110,32 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - var maybeRelation: Option[HadoopFsRelation] = None + var maybeDataReader: Option[DataSourceReader] = None val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => - maybeRelation = Some(orcRelation) + case PhysicalOperation(_, filters, + DataSourceV2Relation(orcRelation: OrcDataSourceV2, options, _, _, userSpecifiedSchema)) => + val dataSourceOptions = new DataSourceOptions(options.asJava) + val dataReader = if (userSpecifiedSchema.isDefined) { + orcRelation.asInstanceOf[ReadSupportWithSchema] + .createReader(userSpecifiedSchema.get, dataSourceOptions) + } else { + orcRelation.asInstanceOf[ReadSupport].createReader(dataSourceOptions) + } + maybeDataReader = Some(dataReader) filters }.flatten.reduceLeftOption(_ && _) assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") - val (_, selectedFilters, _) = - DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) - assert(selectedFilters.nonEmpty, "No filter is pushed down") + val pushDownCatalystFiltersReader = + maybeDataReader.get.asInstanceOf[SupportsPushDownCatalystFilters] + pushDownCatalystFiltersReader.pushCatalystFilters(Array(maybeAnalyzedPredicate.get)) + val selectedCatalystFilters = + pushDownCatalystFiltersReader.pushedCatalystFilters() + assert(selectedCatalystFilters.nonEmpty, "No filter is pushed down") + val selectedFilters = selectedCatalystFilters.flatMap(DataSourceStrategy.translateFilter) val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters) - assert(maybeFilter.isEmpty, s"Could generate filter predicate for $selectedFilters") + assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for $selectedFilters") } test("filter pushdown - integer") { @@ -340,7 +368,7 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { } } - test("no filter pushdown - non-supported types") { + ignore("no filter pushdown - non-supported types") { implicit class IntToBinary(int: Int) { def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala index d1911ea7f32a..15e6500b894d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.File +import org.apache.spark.SparkConf import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext // The data where the partitioning key exists only in the directory structure. @@ -74,7 +76,7 @@ abstract class OrcPartitionDiscoveryTest extends OrcTest { ps <- Seq("foo", "bar") } yield Row(i, i.toString, pi, ps)) - checkAnswer( + checkAnswer( sql("SELECT intField, pi FROM t"), for { i <- 1 to 10 @@ -227,3 +229,8 @@ abstract class OrcPartitionDiscoveryTest extends OrcTest { } class OrcPartitionDiscoverySuite extends OrcPartitionDiscoveryTest with SharedSQLContext + +class OrcV1PartitionDiscoverySuite extends OrcPartitionDiscoveryTest with SharedSQLContext { + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.DISABLED_V2_FILE_DATA_SOURCE_READERS, "orc") +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index f58c331f33ca..3a7a24c6f055 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -31,7 +31,7 @@ import org.apache.orc.OrcConf.COMPRESS import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce.OrcInputFormat -import org.apache.spark.SparkException +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} @@ -656,3 +656,8 @@ class OrcQuerySuite extends OrcQueryTest with SharedSQLContext { } } } + +class OrcV1QuerySuite extends OrcQuerySuite { + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.DISABLED_V2_FILE_DATA_SOURCE_READERS, "orc") +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala new file mode 100644 index 000000000000..de7db3004f45 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.sources.v2 + +import java.util.Optional + +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetTest} +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.DataSourceReader +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 with ReadSupport { + override def createReader(options: DataSourceOptions): DataSourceReader = { + throw new AnalysisException("Dummy file reader") + } + + override def fallBackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat] + + override def shortName(): String = "parquet" +} + +class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 with WriteSupport { + override def createWriter( + jobId: String, + schema: StructType, + mode: SaveMode, + options: DataSourceOptions): Optional[DataSourceWriter] = { + throw new AnalysisException("Dummy file writer") + } + + override def fallBackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat] + + override def shortName(): String = "parquet" +} + +class SimpleFileDataSourceV2 extends SimpleDataSourceV2 with FileDataSourceV2 { + override def fallBackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat] + + override def shortName(): String = "parquet" +} + +class FileDataSourceV2FallBackSuite extends QueryTest with ParquetTest with SharedSQLContext { + import testImplicits._ + + private val dummyParquetReaderV2 = classOf[DummyReadOnlyFileDataSourceV2].getName + private val dummyParquetWriterV2 = classOf[DummyWriteOnlyFileDataSourceV2].getName + private val simpleFileDataSourceV2 = classOf[SimpleFileDataSourceV2].getName + + test("Fall back to v1 when writing to file with read only FileDataSourceV2") { + val df = spark.range(10).toDF() + withTempPath { file => + val path = file.getCanonicalPath + // Writing file should fall back to v1 and succeed. + df.write.format(dummyParquetReaderV2).save(path) + + // Validate write result with [[ParquetFileFormat]]. + checkAnswer(spark.read.parquet(path), df) + + // Dummy File reader should fail as expected. + val exception = intercept[AnalysisException] { + spark.read.format(dummyParquetReaderV2).load(path) + } + assert(exception.message.equals("Dummy file reader")) + } + } + + test("Fall back to v1 when reading file with write only FileDataSourceV2") { + val df = spark.range(10).toDF() + withTempPath { file => + val path = file.getCanonicalPath + + // Dummy File writer should fail as expected. + val exception = intercept[AnalysisException] { + df.write.format(dummyParquetWriterV2).save(path) + } + assert(exception.message.equals("Dummy file writer")) + + df.write.parquet(path) + // Reading file should fall back to v1 and succeed. + checkAnswer(spark.read.format(dummyParquetWriterV2).load(path), df) + } + } + + test("Fall back read path to v1 with configuration DISABLED_V2_FILE_DATA_SOURCE_READERS") { + val df = spark.range(10).toDF() + withTempPath { file => + val path = file.getCanonicalPath + df.write.parquet(path) + withSQLConf(SQLConf.DISABLED_V2_FILE_DATA_SOURCE_READERS.key -> "foo,parquet,bar") { + // Reading file should fall back to v1 and succeed. + checkAnswer(spark.read.format(dummyParquetReaderV2).load(path), df) + } + + withSQLConf(SQLConf.DISABLED_V2_FILE_DATA_SOURCE_READERS.key -> "foo,bar") { + // Dummy File reader should fail as DISABLED_V2_FILE_DATA_SOURCE_READERS doesn't include it. + val exception = intercept[AnalysisException] { + spark.read.format(dummyParquetReaderV2).load(path) + } + assert(exception.message.equals("Dummy file reader")) + } + } + } + + test("Fall back write path to v1 with configuration DISABLED_V2_FILE_DATA_SOURCE_READERS") { + val df = spark.range(10).toDF() + withTempPath { file => + val path = file.getCanonicalPath + + withSQLConf(SQLConf.DISABLED_V2_FILE_DATA_SOURCE_WRITERS.key -> "foo,bar") { + // Dummy File writer should fail as expected. + val exception = intercept[AnalysisException] { + df.write.format(dummyParquetWriterV2).save(path) + } + assert(exception.message.equals("Dummy file writer")) + } + + withSQLConf(SQLConf.DISABLED_V2_FILE_DATA_SOURCE_WRITERS.key -> "foo,parquet,bar") { + // Writing file should fall back to v1 and succeed. + df.write.format(dummyParquetWriterV2).save(path) + } + + checkAnswer(spark.read.format(dummyParquetWriterV2).load(path), df) + } + } + + test("InsertIntoTable: Fall back to V1") { + val df1 = (100 until 105).map(i => (i, -i)).toDF("i", "j") + val df2 = (5 until 10).map(i => (i, -i)).toDF("i", "j") + withTempPath { file => + val path = file.getCanonicalPath + withTempView("tmp", "tbl") { + df1.createOrReplaceTempView("tmp") + df2.write.parquet(path) + // Create temporary view with FileDataSourceV2 + spark.read.format(simpleFileDataSourceV2).load(path).createOrReplaceTempView("tbl") + sql("INSERT INTO TABLE tbl SELECT * FROM tmp") + checkAnswer(spark.read.parquet(path), df1.union(df2)) + } + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 2882672f327c..c26dca239de3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -71,6 +71,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session new ResolveHiveSerdeTable(session) +: new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: + new FallBackFileDataSourceToV1(session) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index d93215fefb81..3402ed240f8b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -908,7 +908,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } assert(e.getMessage.contains( "The format of the existing table default.appendOrcToParquet is `ParquetFileFormat`. " + - "It doesn't match the specified format `OrcFileFormat`")) + "It doesn't match the specified format")) } withTable("appendParquetToJson") {