@@ -153,6 +153,11 @@ class OrcFileFormat
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
if (sparkSession.sessionState.conf.orcFilterPushDown) {
OrcFilters.createFilter(dataSchema, filters).foreach { f =>
OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames)
}
}

val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
val sqlConf = sparkSession.sessionState.conf
@@ -164,8 +169,6 @@ class OrcFileFormat
val broadcastedConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown
val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles

(file: PartitionedFile) => {
val conf = broadcastedConf.value.value
@@ -183,15 +186,6 @@
if (resultedColPruneInfo.isEmpty) {
Iterator.empty
} else {
// ORC predicate pushdown
if (orcFilterPushDown) {
OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
OrcFilters.createFilter(fileSchema, filters).foreach { f =>
OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
}
}
}

val (requestedColIds, canPruneCols) = resultedColPruneInfo.get
val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols,
dataSchema, resultSchema, partitionSchema, conf)
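The hunk above moves ORC predicate pushdown for the v1 OrcFileFormat back to the driver: the SearchArgument is built once from dataSchema and serialized into hadoopConf before the per-file reader closure is created, rather than being rebuilt inside each task from the file schema. A minimal standalone sketch of that driver-side step follows; OrcFilters is private[sql] and the mapreduce OrcInputFormat import is an assumption, so this is illustrative rather than user-callable code.

    import org.apache.hadoop.conf.Configuration
    import org.apache.orc.mapreduce.OrcInputFormat
    import org.apache.spark.sql.execution.datasources.orc.OrcFilters
    import org.apache.spark.sql.sources.{Filter, GreaterThan}
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    val dataSchema = StructType(Seq(StructField("cint", IntegerType)))
    val filters: Seq[Filter] = Seq(GreaterThan("cint", 1))
    val hadoopConf = new Configuration()

    // createFilter yields a SearchArgument only when the filters are convertible;
    // setSearchArgument then serializes it into the conf that is broadcast to executors.
    OrcFilters.createFilter(dataSchema, filters).foreach { sarg =>
      OrcInputFormat.setSearchArgument(hadoopConf, sarg, dataSchema.fieldNames)
    }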
@@ -39,8 +39,6 @@ trait OrcFiltersBase {
}
}

case class OrcPrimitiveField(fieldName: String, fieldType: DataType)

/**
* This method returns a map which contains ORC field name and data type. Each key
* represents a column; `dots` are used as separators for nested columns. If any part
@@ -51,21 +49,19 @@
*/
protected[sql] def getSearchableTypeMap(
schema: StructType,
caseSensitive: Boolean): Map[String, OrcPrimitiveField] = {
caseSensitive: Boolean): Map[String, DataType] = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper

def getPrimitiveFields(
fields: Seq[StructField],
parentFieldNames: Seq[String] = Seq.empty): Seq[(String, OrcPrimitiveField)] = {
parentFieldNames: Seq[String] = Seq.empty): Seq[(String, DataType)] = {
fields.flatMap { f =>
f.dataType match {
case st: StructType =>
getPrimitiveFields(st.fields, parentFieldNames :+ f.name)
case BinaryType => None
case _: AtomicType =>
val fieldName = (parentFieldNames :+ f.name).quoted
val orcField = OrcPrimitiveField(fieldName, f.dataType)
Some((fieldName, orcField))
Some(((parentFieldNames :+ f.name).quoted, f.dataType))
case _ => None
}
}
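With OrcPrimitiveField gone, getSearchableTypeMap again returns a plain Map[String, DataType] keyed by dot-separated, quoted field names, as the doc comment above describes; BinaryType and non-atomic types remain excluded. A rough illustration under those assumptions (the method is protected[sql], so the expected result is shown as comments):

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("bin", BinaryType),    // binary columns are not searchable
      StructField("nested", StructType(Seq(StructField("cint", IntegerType))))))

    // Expected searchable-type map for this schema, roughly:
    //   Map("id" -> IntegerType, "nested.cint" -> IntegerType)
    // The nested column is flattened to a dot-separated name and "bin" is dropped.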
@@ -92,20 +92,6 @@ object OrcUtils extends Logging {
}
}

def readCatalystSchema(
file: Path,
conf: Configuration,
ignoreCorruptFiles: Boolean): Option[StructType] = {
readSchema(file, conf, ignoreCorruptFiles) match {
case Some(schema) =>
Some(CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType])

case None =>
// Field names is empty or `FileFormatException` was thrown but ignoreCorruptFiles is true.
None
}
}

/**
* Reads ORC file schemas in multi-threaded manner, using native version of ORC.
* This is visible for testing.
@@ -31,10 +31,9 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils}
import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcUtils}
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{AtomicType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -53,39 +52,24 @@ case class OrcPartitionReaderFactory(
broadcastedConf: Broadcast[SerializableConfiguration],
dataSchema: StructType,
readDataSchema: StructType,
partitionSchema: StructType,
filters: Array[Filter]) extends FilePartitionReaderFactory {
partitionSchema: StructType) extends FilePartitionReaderFactory {
private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
private val capacity = sqlConf.orcVectorizedReaderBatchSize
private val orcFilterPushDown = sqlConf.orcFilterPushDown
private val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles

override def supportColumnarReads(partition: InputPartition): Boolean = {
sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
}

private def pushDownPredicates(filePath: Path, conf: Configuration): Unit = {
if (orcFilterPushDown) {
OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
OrcFilters.createFilter(fileSchema, filters).foreach { f =>
OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
}
}
}
}

override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
val conf = broadcastedConf.value.value

OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)

val filePath = new Path(new URI(file.filePath))

pushDownPredicates(filePath, conf)

val fs = filePath.getFileSystem(conf)
val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
val resultedColPruneInfo =
@@ -132,8 +116,6 @@ case class OrcPartitionReaderFactory(

val filePath = new Path(new URI(file.filePath))

pushDownPredicates(filePath, conf)

val fs = filePath.getFileSystem(conf)
val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
val resultedColPruneInfo =
@@ -48,7 +48,7 @@ case class OrcScan(
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, readDataSchema, readPartitionSchema, pushedFilters)
dataSchema, readDataSchema, readPartitionSchema)
}

override def equals(obj: Any): Boolean = obj match {
@@ -56,6 +56,11 @@ case class OrcScanBuilder(

override def pushFilters(filters: Array[Filter]): Array[Filter] = {
if (sparkSession.sessionState.conf.orcFilterPushDown) {
OrcFilters.createFilter(schema, filters).foreach { f =>
// The pushed filters will be set in `hadoopConf`. After that, we can simply use the
// changed `hadoopConf` in executors.
OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames)
}
val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
_pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray
}
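With this change, the v2 OrcScanBuilder builds the SearchArgument once in pushFilters and stores it in hadoopConf, so executors read it from the broadcast configuration, while only the convertible filters are reported back as pushed. A small hedged usage sketch, assuming a spark-shell style session named spark and a placeholder path and column:

    import spark.implicits._

    // ORC filter pushdown is governed by spark.sql.orc.filterPushdown (on by default).
    spark.conf.set("spark.sql.orc.filterPushdown", "true")

    // placeholder path and column name
    val filtered = spark.read.orc("/tmp/example_orc").filter($"cint" > 1)

    // The convertible predicate should show up under PushedFilters in the ORC scan node.
    filtered.explain()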
@@ -81,7 +81,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {

def convertibleFilters(
schema: StructType,
dataTypeMap: Map[String, OrcPrimitiveField],
dataTypeMap: Map[String, DataType],
filters: Seq[Filter]): Seq[Filter] = {
import org.apache.spark.sql.sources._

@@ -179,7 +179,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* @return the builder so far.
*/
private def buildSearchArgument(
dataTypeMap: Map[String, OrcPrimitiveField],
dataTypeMap: Map[String, DataType],
expression: Filter,
builder: Builder): Builder = {
import org.apache.spark.sql.sources._
@@ -215,7 +215,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* @return the builder so far.
*/
private def buildLeafSearchArgument(
dataTypeMap: Map[String, OrcPrimitiveField],
dataTypeMap: Map[String, DataType],
expression: Filter,
builder: Builder): Option[Builder] = {
def getType(attribute: String): PredicateLeaf.Type =
@@ -228,44 +228,38 @@ private[sql] object OrcFilters extends OrcFiltersBase {
// wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
expression match {
case EqualTo(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.equals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().equals(name, getType(name), castedValue).end())

case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())

case LessThan(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())

case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())

case GreaterThan(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startNot()
.lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())

case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startNot()
.lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startNot().lessThan(name, getType(name), castedValue).end())

case IsNull(name) if dataTypeMap.contains(name) =>
Some(builder.startAnd().isNull(dataTypeMap(name).fieldName, getType(name)).end())
Some(builder.startAnd().isNull(name, getType(name)).end())

case IsNotNull(name) if dataTypeMap.contains(name) =>
Some(builder.startNot().isNull(dataTypeMap(name).fieldName, getType(name)).end())
Some(builder.startNot().isNull(name, getType(name)).end())

case In(name, values) if dataTypeMap.contains(name) =>
val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType))
Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name),
val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
Some(builder.startAnd().in(name, getType(name),
castedValues.map(_.asInstanceOf[AnyRef]): _*).end())

case _ => None
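As the leaf cases above show, filters without a direct ORC predicate leaf are rewritten using the leaves ORC does provide, e.g. GreaterThan(name, v) becomes NOT(name <= v). A minimal hand-built equivalent using the same shaded sarg API the test suite below imports; the column name and literal are placeholders, and integer literals are widened to Long as ORC expects:

    import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgumentFactory}

    // Equivalent of pushing GreaterThan("cint", 1): NOT(cint <= 1)
    val sarg = SearchArgumentFactory.newBuilder()
      .startNot()
      .lessThanEquals("cint", PredicateLeaf.Type.LONG, 1L)
      .`end`()
      .build()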
@@ -24,7 +24,6 @@ import java.sql.{Date, Timestamp}
import scala.collection.JavaConverters._

import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row}
@@ -587,7 +586,8 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
checkAnswer(sql(s"select a from $tableName"), (0 until count).map(c => Row(c - 1)))

val actual = stripSparkFilter(sql(s"select a from $tableName where a < 0"))
assert(actual.count() == 1)
// TODO: ORC predicate pushdown should work under case-insensitive analysis.
// assert(actual.count() == 1)
}
}

@@ -606,71 +606,5 @@
}
}
}

test("SPARK-32646: Case-insensitive field resolution for pushdown when reading ORC") {
import org.apache.spark.sql.sources._

def getOrcFilter(
schema: StructType,
filters: Seq[Filter],
caseSensitive: String): Option[SearchArgument] = {
var orcFilter: Option[SearchArgument] = None
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
orcFilter =
OrcFilters.createFilter(schema, filters)
}
orcFilter
}

def testFilter(
schema: StructType,
filters: Seq[Filter],
expected: SearchArgument): Unit = {
val caseSensitiveFilters = getOrcFilter(schema, filters, "true")
val caseInsensitiveFilters = getOrcFilter(schema, filters, "false")

assert(caseSensitiveFilters.isEmpty)
assert(caseInsensitiveFilters.isDefined)

assert(caseInsensitiveFilters.get.getLeaves().size() > 0)
assert(caseInsensitiveFilters.get.getLeaves().size() == expected.getLeaves().size())
(0 until expected.getLeaves().size()).foreach { index =>
assert(caseInsensitiveFilters.get.getLeaves().get(index) == expected.getLeaves().get(index))
}
}

val schema1 = StructType(Seq(StructField("cint", IntegerType)))
testFilter(schema1, Seq(GreaterThan("CINT", 1)),
newBuilder.startNot()
.lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
testFilter(schema1, Seq(
And(GreaterThan("CINT", 1), EqualTo("Cint", 2))),
newBuilder.startAnd()
.startNot()
.lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
.equals("cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
.`end`().build())

// Nested column case
val schema2 = StructType(Seq(StructField("a",
StructType(Seq(StructField("cint", IntegerType))))))

testFilter(schema2, Seq(GreaterThan("A.CINT", 1)),
newBuilder.startNot()
.lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
testFilter(schema2, Seq(GreaterThan("a.CINT", 1)),
newBuilder.startNot()
.lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
testFilter(schema2, Seq(GreaterThan("A.cint", 1)),
newBuilder.startNot()
.lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
testFilter(schema2, Seq(
And(GreaterThan("a.CINT", 1), EqualTo("a.Cint", 2))),
newBuilder.startAnd()
.startNot()
.lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
.equals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
.`end`().build())
}
}
