diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index f7f68b1eb90d..b015251b1dc6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.execution.datasources
 
 import scala.collection.mutable
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
 
 import org.apache.spark.{Partition => RDDPartition, TaskContext}
 import org.apache.spark.deploy.SparkHadoopUtil
@@ -25,6 +27,7 @@ import org.apache.spark.rdd.{InputFileNameHolder, RDD}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.vectorized.ColumnarBatch
+import org.apache.spark.util.ThreadUtils
 
 /**
  * A single file that should be read, along with partition column values that
@@ -50,12 +53,28 @@ case class PartitionedFile(
  */
 case class FilePartition(index: Int, files: Seq[PartitionedFile]) extends RDDPartition
 
+object FileScanRDD {
+  private val ioExecutionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("FileScanRDD", 16))
+}
+
 class FileScanRDD(
     @transient private val sparkSession: SparkSession,
     readFunction: (PartitionedFile) => Iterator[InternalRow],
     @transient val filePartitions: Seq[FilePartition])
   extends RDD[InternalRow](sparkSession.sparkContext, Nil) {
 
+  /**
+   * To get better interleaving of CPU and IO, this RDD will create a future to prepare the next
+   * file while the current one is being processed. `currentIterator` is the current file and
+   * `nextFile` is the future that will initialize the next file to be read. This includes things
+   * such as starting up connections to open the file and any initial buffering. The expectation
+   * is that `currentIterator` is CPU intensive and `nextFile` is IO intensive.
+   */
+  val isAsyncIOEnabled = sparkSession.sessionState.conf.filesAsyncIO
+
+  case class NextFile(file: PartitionedFile, iter: Iterator[Object])
+
   override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = {
     val iterator = new Iterator[Object] with AutoCloseable {
       private val inputMetrics = context.taskMetrics().inputMetrics
@@ -88,6 +107,9 @@ class FileScanRDD(
       private[this] var currentFile: PartitionedFile = null
       private[this] var currentIterator: Iterator[Object] = null
 
+      private[this] var nextFile: Future[NextFile] =
+        if (isAsyncIOEnabled) prepareNextFile() else null
+
       def hasNext = (currentIterator != null && currentIterator.hasNext) || nextIterator()
       def next() = {
         val nextElement = currentIterator.next()
@@ -107,16 +129,32 @@ class FileScanRDD(
       /** Advances to the next file. Returns true if a new non-empty iterator is available. */
       private def nextIterator(): Boolean = {
         updateBytesReadWithFileSize()
-        if (files.hasNext) {
-          currentFile = files.next()
-          logInfo(s"Reading File $currentFile")
-          InputFileNameHolder.setInputFileName(currentFile.filePath)
-          currentIterator = readFunction(currentFile)
-          hasNext
+        if (isAsyncIOEnabled) {
+          if (nextFile != null) {
+            // Wait for the async task to complete
+            val file = ThreadUtils.awaitResult(nextFile, Duration.Inf)
+            InputFileNameHolder.setInputFileName(file.file.filePath)
+            currentIterator = file.iter
+            // Asynchronously start the next file.
+            nextFile = prepareNextFile()
+            hasNext
+          } else {
+            currentFile = null
+            InputFileNameHolder.unsetInputFileName()
+            false
+          }
         } else {
-          currentFile = null
-          InputFileNameHolder.unsetInputFileName()
-          false
+          if (files.hasNext) {
+            currentFile = files.next()
+            logInfo(s"Reading File $currentFile")
+            InputFileNameHolder.setInputFileName(currentFile.filePath)
+            currentIterator = readFunction(currentFile)
+            hasNext
+          } else {
+            currentFile = null
+            InputFileNameHolder.unsetInputFileName()
+            false
+          }
         }
       }
 
@@ -125,6 +163,20 @@ class FileScanRDD(
         updateBytesReadWithFileSize()
         InputFileNameHolder.unsetInputFileName()
       }
+
+      def prepareNextFile(): Future[NextFile] = {
+        if (files.hasNext) {
+          Future {
+            val nextFile = files.next()
+            val nextFileIter = readFunction(nextFile)
+            // Read something from the file to trigger some initial IO.
+            nextFileIter.hasNext
+            NextFile(nextFile, nextFileIter)
+          }(FileScanRDD.ioExecutionContext)
+        } else {
+          null
+        }
+      }
     }
 
     // Register an on-task-completion callback to close the input stream.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6fbf32676f5a..70d5eee4928d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -438,6 +438,12 @@ object SQLConf {
     .longConf
     .createWithDefault(4 * 1024 * 1024)
 
+  val FILES_ASYNC_IO = SQLConfigBuilder("spark.sql.files.asyncIO")
+    .internal()
+    .doc("If true, attempts to asynchronously do IO when reading data.")
+    .booleanConf
+    .createWithDefault(true)
+
   val EXCHANGE_REUSE_ENABLED = SQLConfigBuilder("spark.sql.exchange.reuse")
     .internal()
     .doc("When true, the planner will try to find out duplicated exchanges and re-use them.")
@@ -547,6 +553,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES)
 
+  def filesAsyncIO: Boolean = getConf(FILES_ASYNC_IO)
+
   def useCompression: Boolean = getConf(COMPRESS_CACHED)
 
   def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 32fe5ba127ca..b11f3e156112 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -78,10 +78,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
   }
 
   test("basic data types (without binary)") {
-    val data = (1 to 4).map { i =>
-      (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
+    (true :: false :: Nil).foreach { v =>
+      withSQLConf(SQLConf.FILES_ASYNC_IO.key -> v.toString) {
+        val data = (1 to 4).map { i =>
+          (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
+        }
+        checkParquetFile(data)
+      }
     }
-    checkParquetFile(data)
  }
 
   test("raw binary") {
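
Note (not part of the patch): the standalone sketch below illustrates the prefetch pattern the FileScanRDD change relies on. IO for the next file is started on a dedicated thread pool while CPU-bound work consumes the current file, and the consumer blocks only when the next file is actually needed. PrefetchSketch, openFile, and process are hypothetical stand-ins, and plain Await.result is used in place of Spark's ThreadUtils.awaitResult; none of this code ships with Spark.

import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object PrefetchSketch {
  // Dedicated IO pool, playing the role of FileScanRDD.ioExecutionContext in the patch.
  private val ioPool = ExecutionContext.fromExecutorService(Executors.newCachedThreadPool())

  // Hypothetical IO-heavy open, standing in for readFunction(PartitionedFile).
  private def openFile(path: String): Iterator[String] = {
    Thread.sleep(100) // simulate connection setup and initial buffering
    Iterator(s"row-from-$path")
  }

  // Hypothetical CPU-heavy per-row work.
  private def process(row: String): Unit = println(row)

  def main(args: Array[String]): Unit = {
    val paths = Iterator("f1", "f2", "f3")

    // Mirrors prepareNextFile(): open the next file on the IO pool and touch it
    // (hasNext) so the initial IO happens off the consuming thread.
    def prepareNext(): Option[Future[Iterator[String]]] =
      if (paths.hasNext) {
        val p = paths.next()
        Some(Future { val it = openFile(p); it.hasNext; it }(ioPool))
      } else {
        None
      }

    var next = prepareNext()
    while (next.isDefined) {
      val current = Await.result(next.get, Duration.Inf) // block only when the file is needed
      next = prepareNext()                               // kick off IO for the following file
      current.foreach(process)                           // CPU work overlaps that IO
    }
    ioPool.shutdown()
  }
}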