@@ -153,11 +153,6 @@ class OrcFileFormat
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
if (sparkSession.sessionState.conf.orcFilterPushDown) {
OrcFilters.createFilter(dataSchema, filters).foreach { f =>
OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames)
}
}

val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
val sqlConf = sparkSession.sessionState.conf
@@ -169,6 +164,8 @@ class OrcFileFormat
val broadcastedConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown
val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles

(file: PartitionedFile) => {
val conf = broadcastedConf.value.value
@@ -186,6 +183,15 @@ class OrcFileFormat
if (resultedColPruneInfo.isEmpty) {
Iterator.empty
} else {
// ORC predicate pushdown
if (orcFilterPushDown) {
OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
OrcFilters.createFilter(fileSchema, filters).foreach { f =>
Contributor:

AFAIK the input schema is only used to build the dataTypeMap. Can we directly pass the physical ORC schema here? Then we don't need to convert the ORC schema to a catalyst schema, which would be consistent with the Parquet side.

Member Author:

Our ORC pushdown code, e.g., buildLeafSearchArgument, uses the catalyst schema in many places. I think it is still possible to use ORC's schema there and remove the catalyst schema, but it looks like a big change. If we want to do it, I'd suggest doing it in another PR; this diff is already not small.
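For context, a minimal sketch of the kind of catalyst-type dependency referred to above. It is an illustrative approximation of the mapping performed by OrcFilters.getPredicateLeafType, not code from this patch:

import org.apache.orc.storage.ql.io.sarg.PredicateLeaf
import org.apache.spark.sql.types._

// Illustrative only: building a SearchArgument leaf requires an ORC
// PredicateLeaf.Type, and that type is derived from the catalyst DataType,
// which is why the pushdown path currently needs a catalyst schema.
def toPredicateLeafType(dataType: DataType): PredicateLeaf.Type = dataType match {
  case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
  case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
  case StringType => PredicateLeaf.Type.STRING
  case BooleanType => PredicateLeaf.Type.BOOLEAN
  case DateType => PredicateLeaf.Type.DATE
  case TimestampType => PredicateLeaf.Type.TIMESTAMP
  case _: DecimalType => PredicateLeaf.Type.DECIMAL
  case other => throw new IllegalArgumentException(s"Unsupported type for pushdown: $other")
}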

OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
}
}
}

val (requestedColIds, canPruneCols) = resultedColPruneInfo.get
val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols,
dataSchema, resultSchema, partitionSchema, conf)
@@ -39,6 +39,8 @@ trait OrcFiltersBase
}
}

case class OrcPrimitiveField(fieldName: String, fieldType: DataType)

/**
* This method returns a map which contains ORC field name and data type. Each key
* represents a column; `dots` are used as separators for nested columns. If any part
@@ -49,19 +51,21 @@
*/
protected[sql] def getSearchableTypeMap(
schema: StructType,
caseSensitive: Boolean): Map[String, DataType] = {
caseSensitive: Boolean): Map[String, OrcPrimitiveField] = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper

def getPrimitiveFields(
fields: Seq[StructField],
parentFieldNames: Seq[String] = Seq.empty): Seq[(String, DataType)] = {
parentFieldNames: Seq[String] = Seq.empty): Seq[(String, OrcPrimitiveField)] = {
fields.flatMap { f =>
f.dataType match {
case st: StructType =>
getPrimitiveFields(st.fields, parentFieldNames :+ f.name)
case BinaryType => None
case _: AtomicType =>
Some(((parentFieldNames :+ f.name).quoted, f.dataType))
val fieldName = (parentFieldNames :+ f.name).quoted
val orcField = OrcPrimitiveField(fieldName, f.dataType)
Some((fieldName, orcField))
Contributor:

why do we need the OrcPrimitiveField if fieldName is always orcField.fieldName here?

Member Author:

Because we look up the attribute name of pushed-down predicates in the map via dataTypeMap(name), we need to easily retrieve the file schema's field name. If we simply use Map[String, DataType], the matched attribute name under case-insensitive analysis could be in a different letter case from the file schema's field name.

E.g., if the ORC file has an "ABC" field and the pushed-down predicate is "abc > 0", then dataTypeMap.contains("abc") is true and we can get dataTypeMap("abc") back. But we need the file schema's field name "ABC".

It is similar to Parquet:

private val nameToParquetField : Map[String, ParquetPrimitiveField] = {
// Recursively traverse the parquet schema to get primitive fields that can be pushed-down.
// `parentFieldNames` is used to keep track of the current nested level when traversing.
def getPrimitiveFields(
fields: Seq[Type],
parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = {
fields.flatMap {
case p: PrimitiveType =>
Some(ParquetPrimitiveField(fieldNames = parentFieldNames :+ p.getName,
fieldType = ParquetSchemaType(p.getOriginalType,
p.getPrimitiveTypeName, p.getTypeLength, p.getDecimalMetadata)))
// Note that when g is a `Struct`, `g.getOriginalType` is `null`.
// When g is a `Map`, `g.getOriginalType` is `MAP`.
// When g is a `List`, `g.getOriginalType` is `LIST`.
case g: GroupType if g.getOriginalType == null =>
getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName)
// Parquet only supports push-down for primitive types; as a result, Map and List types
// are removed.
case _ => None
}
}
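As a minimal, self-contained illustration of the look-up behaviour described above (a sketch, not part of this patch; it assumes Spark's CaseInsensitiveMap utility for the case-insensitive path):

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types.{DataType, IntegerType}

// The ORC file declares a column "ABC", but the pushed-down predicate refers
// to "abc". A plain Map[String, DataType] would match the key case-insensitively
// yet lose the original spelling, so the value also carries the file schema's
// field name (mirroring the OrcPrimitiveField added by this patch).
case class OrcPrimitiveField(fieldName: String, fieldType: DataType)

val dataTypeMap = CaseInsensitiveMap(Map("ABC" -> OrcPrimitiveField("ABC", IntegerType)))

assert(dataTypeMap.contains("abc"))            // case-insensitive lookup succeeds
assert(dataTypeMap("abc").fieldName == "ABC")  // and still yields the file's exact field name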

case _ => None
}
}
@@ -92,6 +92,20 @@ object OrcUtils extends Logging {
}
}

def readCatalystSchema(
file: Path,
conf: Configuration,
ignoreCorruptFiles: Boolean): Option[StructType] = {
readSchema(file, conf, ignoreCorruptFiles) match {
case Some(schema) =>
Some(CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType])

case None =>
// Field names are empty or a `FileFormatException` was thrown but ignoreCorruptFiles is true.
None
}
}

/**
* Reads ORC file schemas in multi-threaded manner, using native version of ORC.
* This is visible for testing.
@@ -31,9 +31,10 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcUtils}
import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils}
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{AtomicType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -52,24 +53,39 @@ case class OrcPartitionReaderFactory(
broadcastedConf: Broadcast[SerializableConfiguration],
dataSchema: StructType,
readDataSchema: StructType,
partitionSchema: StructType) extends FilePartitionReaderFactory {
partitionSchema: StructType,
filters: Array[Filter]) extends FilePartitionReaderFactory {
private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
private val capacity = sqlConf.orcVectorizedReaderBatchSize
private val orcFilterPushDown = sqlConf.orcFilterPushDown
private val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles

override def supportColumnarReads(partition: InputPartition): Boolean = {
sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
}

private def pushDownPredicates(filePath: Path, conf: Configuration): Unit = {
if (orcFilterPushDown) {
OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
OrcFilters.createFilter(fileSchema, filters).foreach { f =>
OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
}
}
}
}

override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
val conf = broadcastedConf.value.value

OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)

val filePath = new Path(new URI(file.filePath))

pushDownPredicates(filePath, conf)

val fs = filePath.getFileSystem(conf)
val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
val resultedColPruneInfo =
@@ -116,6 +132,8 @@ case class OrcPartitionReaderFactory(

val filePath = new Path(new URI(file.filePath))

pushDownPredicates(filePath, conf)

val fs = filePath.getFileSystem(conf)
val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
val resultedColPruneInfo =
@@ -48,7 +48,7 @@ case class OrcScan(
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, readDataSchema, readPartitionSchema)
dataSchema, readDataSchema, readPartitionSchema, pushedFilters)
}

override def equals(obj: Any): Boolean = obj match {
@@ -56,11 +56,6 @@ case class OrcScanBuilder(

override def pushFilters(filters: Array[Filter]): Array[Filter] = {
if (sparkSession.sessionState.conf.orcFilterPushDown) {
OrcFilters.createFilter(schema, filters).foreach { f =>
// The pushed filters will be set in `hadoopConf`. After that, we can simply use the
// changed `hadoopConf` in executors.
OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames)
}
val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
_pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray
}
@@ -81,7 +81,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {

def convertibleFilters(
schema: StructType,
dataTypeMap: Map[String, DataType],
dataTypeMap: Map[String, OrcPrimitiveField],
filters: Seq[Filter]): Seq[Filter] = {
import org.apache.spark.sql.sources._

@@ -179,7 +179,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* @return the builder so far.
*/
private def buildSearchArgument(
dataTypeMap: Map[String, DataType],
dataTypeMap: Map[String, OrcPrimitiveField],
expression: Filter,
builder: Builder): Builder = {
import org.apache.spark.sql.sources._
@@ -215,7 +215,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* @return the builder so far.
*/
private def buildLeafSearchArgument(
dataTypeMap: Map[String, DataType],
dataTypeMap: Map[String, OrcPrimitiveField],
expression: Filter,
builder: Builder): Option[Builder] = {
def getType(attribute: String): PredicateLeaf.Type =
@@ -228,38 +228,44 @@
// wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
expression match {
case EqualTo(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().equals(name, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.equals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case LessThan(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case GreaterThan(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startNot()
.lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startNot().lessThan(name, getType(name), castedValue).end())
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startNot()
.lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case IsNull(name) if dataTypeMap.contains(name) =>
Some(builder.startAnd().isNull(name, getType(name)).end())
Some(builder.startAnd().isNull(dataTypeMap(name).fieldName, getType(name)).end())

case IsNotNull(name) if dataTypeMap.contains(name) =>
Some(builder.startNot().isNull(name, getType(name)).end())
Some(builder.startNot().isNull(dataTypeMap(name).fieldName, getType(name)).end())

case In(name, values) if dataTypeMap.contains(name) =>
val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
Some(builder.startAnd().in(name, getType(name),
val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType))
Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name),
castedValues.map(_.asInstanceOf[AnyRef]): _*).end())

case _ => None
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
import scala.collection.JavaConverters._

import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row}
@@ -586,8 +587,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
checkAnswer(sql(s"select a from $tableName"), (0 until count).map(c => Row(c - 1)))

val actual = stripSparkFilter(sql(s"select a from $tableName where a < 0"))
// TODO: ORC predicate pushdown should work under case-insensitive analysis.
// assert(actual.count() == 1)
assert(actual.count() == 1)
}
}

@@ -606,5 +606,71 @@
}
}
}

test("SPARK-32646: Case-insensitive field resolution for pushdown when reading ORC") {
import org.apache.spark.sql.sources._

def getOrcFilter(
schema: StructType,
filters: Seq[Filter],
caseSensitive: String): Option[SearchArgument] = {
var orcFilter: Option[SearchArgument] = None
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
orcFilter =
OrcFilters.createFilter(schema, filters)
}
orcFilter
}

def testFilter(
schema: StructType,
filters: Seq[Filter],
expected: SearchArgument): Unit = {
val caseSensitiveFilters = getOrcFilter(schema, filters, "true")
val caseInsensitiveFilters = getOrcFilter(schema, filters, "false")

assert(caseSensitiveFilters.isEmpty)
assert(caseInsensitiveFilters.isDefined)

assert(caseInsensitiveFilters.get.getLeaves().size() > 0)
assert(caseInsensitiveFilters.get.getLeaves().size() == expected.getLeaves().size())
(0 until expected.getLeaves().size()).foreach { index =>
assert(caseInsensitiveFilters.get.getLeaves().get(index) == expected.getLeaves().get(index))
}
}

val schema1 = StructType(Seq(StructField("cint", IntegerType)))
testFilter(schema1, Seq(GreaterThan("CINT", 1)),
newBuilder.startNot()
.lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
testFilter(schema1, Seq(
And(GreaterThan("CINT", 1), EqualTo("Cint", 2))),
newBuilder.startAnd()
.startNot()
.lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
.equals("cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
.`end`().build())

// Nested column case
val schema2 = StructType(Seq(StructField("a",
StructType(Seq(StructField("cint", IntegerType))))))

testFilter(schema2, Seq(GreaterThan("A.CINT", 1)),
newBuilder.startNot()
.lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
testFilter(schema2, Seq(GreaterThan("a.CINT", 1)),
newBuilder.startNot()
.lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
testFilter(schema2, Seq(GreaterThan("A.cint", 1)),
newBuilder.startNot()
.lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
testFilter(schema2, Seq(
And(GreaterThan("a.CINT", 1), EqualTo("a.Cint", 2))),
newBuilder.startAnd()
.startNot()
.lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
.equals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
.`end`().build())
}
}
