@@ -46,19 +46,6 @@ abstract class FilePartitionReaderFactory extends PartitionReaderFactory {
def buildColumnarReader(partitionedFile: PartitionedFile): PartitionReader[ColumnarBatch] = {
throw new UnsupportedOperationException("Cannot create columnar reader.")
}

protected def getReadDataSchema(
readSchema: StructType,
partitionSchema: StructType,
isCaseSensitive: Boolean): StructType = {
val partitionNameSet =
partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
val fields = readSchema.fields.filterNot { field =>
partitionNameSet.contains(PartitioningUtils.getColName(field, isCaseSensitive))
}

StructType(fields)
}
}

// A compound class for combining file and its corresponding reader.
@@ -16,9 +16,13 @@
*/
package org.apache.spark.sql.execution.datasources.v2

import java.util.Locale

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.PartitionedFileUtil
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan}
@@ -28,8 +32,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
abstract class FileScan(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
readSchema: StructType,
options: CaseInsensitiveStringMap) extends Scan with Batch {
readDataSchema: StructType,
readPartitionSchema: StructType) extends Scan with Batch {
/**
* Returns whether a file with `path` could be split or not.
*/
@@ -40,7 +44,23 @@ abstract class FileScan(
protected def partitions: Seq[FilePartition] = {
val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
val partitionAttributes = fileIndex.partitionSchema.toAttributes
val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap
val readPartitionAttributes = readPartitionSchema.map { readField =>
attributeMap.get(normalizeName(readField.name)).getOrElse {
throw new AnalysisException(s"Can't find required partition column ${readField.name} " +
s"in partition schema ${fileIndex.partitionSchema}")
}
}
lazy val partitionValueProject =
GenerateUnsafeProjection.generate(readPartitionAttributes, partitionAttributes)
val splitFiles = selectedPartitions.flatMap { partition =>
// Prune partition values if part of the partition columns are not required.
val partitionValues = if (readPartitionAttributes != partitionAttributes) {
partitionValueProject(partition.values).copy()
} else {
partition.values
}
partition.files.flatMap { file =>
val filePath = file.getPath
PartitionedFileUtil.splitFiles(
@@ -49,7 +69,7 @@
filePath = filePath,
isSplitable = isSplitable(filePath),
maxSplitBytes = maxSplitBytes,
partitionValues = partition.values
partitionValues = partitionValues
)
}.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
}
@@ -61,4 +81,17 @@
}

override def toBatch: Batch = this

override def readSchema(): StructType =
StructType(readDataSchema.fields ++ readPartitionSchema.fields)

private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis

private def normalizeName(name: String): String = {
if (isCaseSensitive) {
name
} else {
name.toLowerCase(Locale.ROOT)
}
}
}
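
For reference, the partition-value pruning that the new FileScan.partitions performs can be reproduced in isolation. The sketch below is not part of this diff: it uses UnsafeProjection in place of GenerateUnsafeProjection and made-up schemas and values, but it shows how a full partition-values row is projected down to only the partition columns the scan actually needs.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// Table is partitioned by (dt, hr); the query only reads `hr`.
val partitionSchema = StructType(Seq(
  StructField("dt", StringType), StructField("hr", IntegerType)))
val readPartitionSchema = StructType(Seq(StructField("hr", IntegerType)))

val partitionAttributes = partitionSchema.toAttributes
val readPartitionAttributes = readPartitionSchema.map { f =>
  partitionAttributes.find(_.name == f.name).get
}

// Project the full partition-values row down to the required columns only,
// copying the result because the projection reuses its output row.
val prune = UnsafeProjection.create(readPartitionAttributes, partitionAttributes)
val fullValues = InternalRow(UTF8String.fromString("2019-04-01"), 10)
val prunedValues = prune(fullValues).copy()  // row containing only hr = 10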
@@ -16,15 +16,44 @@
*/
package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.sources.v2.reader.{ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitioningUtils}
import org.apache.spark.sql.sources.v2.reader.{ScanBuilder, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.types.StructType

abstract class FileScanBuilder(schema: StructType)
extends ScanBuilder
with SupportsPushDownRequiredColumns {
protected var readSchema = schema
abstract class FileScanBuilder(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
dataSchema: StructType) extends ScanBuilder with SupportsPushDownRequiredColumns {
private val partitionSchema = fileIndex.partitionSchema
private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields)

override def pruneColumns(requiredSchema: StructType): Unit = {
this.readSchema = requiredSchema
this.requiredSchema = requiredSchema
}

protected def readDataSchema(): StructType = {
val requiredNameSet = createRequiredNameSet()
val fields = dataSchema.fields.filter { field =>
val colName = PartitioningUtils.getColName(field, isCaseSensitive)
requiredNameSet.contains(colName) && !partitionNameSet.contains(colName)
}
StructType(fields)
}

protected def readPartitionSchema(): StructType = {
val requiredNameSet = createRequiredNameSet()
val fields = partitionSchema.fields.filter { field =>
val colName = PartitioningUtils.getColName(field, isCaseSensitive)
requiredNameSet.contains(colName)
}
StructType(fields)
}

private def createRequiredNameSet(): Set[String] =
requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet

private val partitionNameSet: Set[String] =
partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
}
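
To make the new split concrete, here is a standalone sketch (not from this patch) with made-up schemas. It mirrors what readDataSchema() and readPartitionSchema() return after pruneColumns(), using plain case-sensitive name matching instead of PartitioningUtils.getColName.

import org.apache.spark.sql.types._

// Two data columns plus one partition column.
val dataSchema = StructType(Seq(
  StructField("id", LongType), StructField("value", StringType)))
val partitionSchema = StructType(Seq(StructField("dt", StringType)))

// Schema passed to pruneColumns(): only `value` and `dt` are required.
val requiredSchema = StructType(Seq(
  StructField("value", StringType), StructField("dt", StringType)))

val requiredNames = requiredSchema.fieldNames.toSet
val partitionNames = partitionSchema.fieldNames.toSet

// Data columns that are required and not partition columns.
val readDataSchema = StructType(dataSchema.fields.filter(f =>
  requiredNames.contains(f.name) && !partitionNames.contains(f.name)))
// Partition columns that are required.
val readPartitionSchema = StructType(partitionSchema.fields.filter(f =>
  requiredNames.contains(f.name)))
// readDataSchema      -> only `value`
// readPartitionSchema -> only `dt`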
@@ -29,9 +29,10 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
abstract class TextBasedFileScan(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
readSchema: StructType,
readDataSchema: StructType,
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap)
extends FileScan(sparkSession, fileIndex, readSchema, options) {
extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
private var codecFactory: CompressionCodecFactory = _

override def isSplitable(path: Path): Boolean = {
@@ -33,24 +33,21 @@ import org.apache.spark.util.SerializableConfiguration
* @param sqlConf SQL configuration.
* @param broadcastedConf Broadcasted serializable Hadoop Configuration.
* @param dataSchema Schema of CSV files.
* @param readDataSchema Required data schema in the batch scan.
* @param partitionSchema Schema of partitions.
* @param readSchema Required schema in the batch scan.
* @param parsedOptions Options for parsing CSV files.
*/
case class CSVPartitionReaderFactory(
sqlConf: SQLConf,
broadcastedConf: Broadcast[SerializableConfiguration],
dataSchema: StructType,
readDataSchema: StructType,
partitionSchema: StructType,
readSchema: StructType,
parsedOptions: CSVOptions) extends FilePartitionReaderFactory {
private val columnPruning = sqlConf.csvColumnPruning
private val readDataSchema =
getReadDataSchema(readSchema, partitionSchema, sqlConf.caseSensitiveAnalysis)

override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
val conf = broadcastedConf.value.value

val parser = new UnivocityParser(
StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
StructType(readDataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
@@ -35,9 +35,10 @@ case class CSVScan(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
dataSchema: StructType,
readSchema: StructType,
readDataSchema: StructType,
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap)
extends TextBasedFileScan(sparkSession, fileIndex, readSchema, options) {
extends TextBasedFileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema, options) {

private lazy val parsedOptions: CSVOptions = new CSVOptions(
options.asScala.toMap,
@@ -53,8 +54,8 @@ case class CSVScan(
// Check a field requirement for corrupt records here to throw an exception in a driver side
ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord)

if (readSchema.length == 1 &&
readSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
if (readDataSchema.length == 1 &&
readDataSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
throw new AnalysisException(
"Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the\n" +
"referenced columns only include the internal corrupt record column\n" +
@@ -72,7 +73,9 @@ case class CSVScan(
val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
val broadcastedConf = sparkSession.sparkContext.broadcast(
new SerializableConfiguration(hadoopConf))
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, fileIndex.partitionSchema, readSchema, parsedOptions)
dataSchema, readDataSchema, readPartitionSchema, parsedOptions)
}
}
@@ -29,9 +29,10 @@ case class CSVScanBuilder(
fileIndex: PartitioningAwareFileIndex,
schema: StructType,
dataSchema: StructType,
options: CaseInsensitiveStringMap) extends FileScanBuilder(schema) {
options: CaseInsensitiveStringMap)
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {

override def build(): Scan = {
CSVScan(sparkSession, fileIndex, dataSchema, readSchema, options)
CSVScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options)
}
}
@@ -46,30 +46,30 @@ import org.apache.spark.util.SerializableConfiguration
* @param sqlConf SQL configuration.
* @param broadcastedConf Broadcast serializable Hadoop Configuration.
* @param dataSchema Schema of orc files.
* @param readDataSchema Required data schema in the batch scan.
* @param partitionSchema Schema of partitions.
* @param readSchema Required schema in the batch scan.
*/
case class OrcPartitionReaderFactory(
sqlConf: SQLConf,
broadcastedConf: Broadcast[SerializableConfiguration],
dataSchema: StructType,
partitionSchema: StructType,
readSchema: StructType) extends FilePartitionReaderFactory {
readDataSchema: StructType,
partitionSchema: StructType) extends FilePartitionReaderFactory {
private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
private val capacity = sqlConf.orcVectorizedReaderBatchSize

override def supportColumnarReads(partition: InputPartition): Boolean = {
sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
readSchema.length <= sqlConf.wholeStageMaxNumFields &&
readSchema.forall(_.dataType.isInstanceOf[AtomicType])
resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
}

override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
val conf = broadcastedConf.value.value

val readDataSchema = getReadDataSchema(readSchema, partitionSchema, isCaseSensitive)
val readDataSchemaString = OrcUtils.orcTypeDescriptionString(readDataSchema)
OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, readDataSchemaString)
val resultSchemaString = OrcUtils.orcTypeDescriptionString(resultSchema)
OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString)
OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)

val filePath = new Path(new URI(file.filePath))
@@ -113,8 +113,8 @@ case class OrcPartitionReaderFactory(
override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
val conf = broadcastedConf.value.value

val readSchemaString = OrcUtils.orcTypeDescriptionString(readSchema)
OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, readSchemaString)
val resultSchemaString = OrcUtils.orcTypeDescriptionString(resultSchema)
OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString)
OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)

val filePath = new Path(new URI(file.filePath))
@@ -124,13 +124,13 @@ case class OrcPartitionReaderFactory(
val reader = OrcFile.createReader(filePath, readerOptions)

val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds(
isCaseSensitive, dataSchema, readSchema, reader, conf)
isCaseSensitive, dataSchema, readDataSchema, reader, conf)

if (requestedColIdsOrEmptyFile.isEmpty) {
new EmptyPartitionReader
} else {
val requestedColIds = requestedColIdsOrEmptyFile.get
assert(requestedColIds.length == readSchema.length,
val requestedColIds = requestedColIdsOrEmptyFile.get ++ Array.fill(partitionSchema.length)(-1)
assert(requestedColIds.length == resultSchema.length,
"[BUG] requested column IDs do not match required schema")
val taskConf = new Configuration(conf)

@@ -140,15 +140,12 @@

val batchReader = new OrcColumnarBatchReader(capacity)
batchReader.initialize(fileSplit, taskAttemptContext)
val columnNameMap = partitionSchema.fields.map(
PartitioningUtils.getColName(_, isCaseSensitive)).zipWithIndex.toMap
val requestedPartitionColIds = readSchema.fields.map { field =>
columnNameMap.getOrElse(PartitioningUtils.getColName(field, isCaseSensitive), -1)
}
val requestedPartitionColIds =
Array.fill(readDataSchema.length)(-1) ++ Range(0, partitionSchema.length)

batchReader.initBatch(
TypeDescription.fromString(readSchemaString),
readSchema.fields,
TypeDescription.fromString(resultSchemaString),
resultSchema.fields,
requestedColIds,
requestedPartitionColIds,
file.partitionValues)
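The column-id bookkeeping above is easy to get lost in, so here is a tiny standalone sketch with made-up ids (not part of the diff). Data columns keep the ids resolved against the ORC file and get -1 as their partition index; partition columns get -1 as their ORC id and consecutive positions into the pruned partition-values row.

// Two required data columns resolved to ORC ids 0 and 3, one partition column.
val orcRequestedColIds = Array(0, 3)
val readDataSchemaLength = 2
val partitionSchemaLength = 1

val requestedColIds = orcRequestedColIds ++ Array.fill(partitionSchemaLength)(-1)
// => Array(0, 3, -1): the last slot is filled from partition values, not the file.
val requestedPartitionColIds =
  Array.fill(readDataSchemaLength)(-1) ++ Range(0, partitionSchemaLength)
// => Array(-1, -1, 0): only the last slot points into the partition-values row.
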
@@ -32,15 +32,18 @@ case class OrcScan(
hadoopConf: Configuration,
fileIndex: PartitioningAwareFileIndex,
dataSchema: StructType,
readSchema: StructType,
readDataSchema: StructType,
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap)
extends FileScan(sparkSession, fileIndex, readSchema, options) {
extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
override def isSplitable(path: Path): Boolean = true

override def createReaderFactory(): PartitionReaderFactory = {
val broadcastedConf = sparkSession.sparkContext.broadcast(
new SerializableConfiguration(hadoopConf))
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, fileIndex.partitionSchema, readSchema)
dataSchema, readDataSchema, readPartitionSchema)
}
}
@@ -36,15 +36,16 @@ case class OrcScanBuilder(
schema: StructType,
dataSchema: StructType,
options: CaseInsensitiveStringMap)
extends FileScanBuilder(schema) with SupportsPushDownFilters {
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters {
lazy val hadoopConf = {
val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
}

override def build(): Scan = {
OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readSchema, options)
OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema,
readDataSchema(), readPartitionSchema(), options)
}

private var _pushedFilters: Array[Filter] = Array.empty