diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 4dff1ec7ebfb..69badb4f7d59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -153,11 +153,6 @@ class OrcFileFormat
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
-    if (sparkSession.sessionState.conf.orcFilterPushDown) {
-      OrcFilters.createFilter(dataSchema, filters).foreach { f =>
-        OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames)
-      }
-    }

     val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
     val sqlConf = sparkSession.sessionState.conf
@@ -169,6 +164,8 @@ class OrcFileFormat
     val broadcastedConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
     val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+    val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown
+    val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles

     (file: PartitionedFile) => {
       val conf = broadcastedConf.value.value
@@ -186,6 +183,15 @@ class OrcFileFormat
       if (resultedColPruneInfo.isEmpty) {
         Iterator.empty
       } else {
+        // ORC predicate pushdown
+        if (orcFilterPushDown) {
+          OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
+            OrcFilters.createFilter(fileSchema, filters).foreach { f =>
+              OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
+            }
+          }
+        }
+
         val (requestedColIds, canPruneCols) = resultedColPruneInfo.get
         val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols,
           dataSchema, resultSchema, partitionSchema, conf)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala
index b277b4da1cf8..ee0c08dd939a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala
@@ -39,6 +39,8 @@ trait OrcFiltersBase {
     }
   }

+  case class OrcPrimitiveField(fieldName: String, fieldType: DataType)
+
   /**
    * This method returns a map which contains ORC field name and data type. Each key
    * represents a column; `dots` are used as separators for nested columns. If any part
@@ -49,19 +51,21 @@ trait OrcFiltersBase {
    */
   protected[sql] def getSearchableTypeMap(
       schema: StructType,
-      caseSensitive: Boolean): Map[String, DataType] = {
+      caseSensitive: Boolean): Map[String, OrcPrimitiveField] = {
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper

     def getPrimitiveFields(
         fields: Seq[StructField],
-        parentFieldNames: Seq[String] = Seq.empty): Seq[(String, DataType)] = {
+        parentFieldNames: Seq[String] = Seq.empty): Seq[(String, OrcPrimitiveField)] = {
       fields.flatMap { f =>
         f.dataType match {
           case st: StructType => getPrimitiveFields(st.fields, parentFieldNames :+ f.name)
           case BinaryType => None
           case _: AtomicType =>
-            Some(((parentFieldNames :+ f.name).quoted, f.dataType))
+            val fieldName = (parentFieldNames :+ f.name).quoted
+            val orcField = OrcPrimitiveField(fieldName, f.dataType)
+            Some((fieldName, orcField))
           case _ => None
         }
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index 072e670081d1..264cf8165e13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -92,6 +92,20 @@ object OrcUtils extends Logging {
     }
   }

+  def readCatalystSchema(
+      file: Path,
+      conf: Configuration,
+      ignoreCorruptFiles: Boolean): Option[StructType] = {
+    readSchema(file, conf, ignoreCorruptFiles) match {
+      case Some(schema) =>
+        Some(CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType])
+
+      case None =>
+        // Field names is empty or `FileFormatException` was thrown but ignoreCorruptFiles is true.
+        None
+    }
+  }
+
   /**
    * Reads ORC file schemas in multi-threaded manner, using native version of ORC.
    * This is visible for testing.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index 7f25f7bd135f..1f38128e98fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -31,9 +31,10 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
 import org.apache.spark.sql.execution.datasources.PartitionedFile
-import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcUtils}
+import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils}
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.{AtomicType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -52,10 +53,13 @@ case class OrcPartitionReaderFactory(
     broadcastedConf: Broadcast[SerializableConfiguration],
     dataSchema: StructType,
     readDataSchema: StructType,
-    partitionSchema: StructType) extends FilePartitionReaderFactory {
+    partitionSchema: StructType,
+    filters: Array[Filter]) extends FilePartitionReaderFactory {
   private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
   private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
   private val capacity = sqlConf.orcVectorizedReaderBatchSize
+  private val orcFilterPushDown = sqlConf.orcFilterPushDown
+  private val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles

   override def supportColumnarReads(partition: InputPartition): Boolean = {
     sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
@@ -63,6 +67,16 @@ case class OrcPartitionReaderFactory(
     resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
   }

+  private def pushDownPredicates(filePath: Path, conf: Configuration): Unit = {
+    if (orcFilterPushDown) {
+      OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
+        OrcFilters.createFilter(fileSchema, filters).foreach { f =>
+          OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
+        }
+      }
+    }
+  }
+
   override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
     val conf = broadcastedConf.value.value

@@ -70,6 +84,8 @@

     val filePath = new Path(new URI(file.filePath))

+    pushDownPredicates(filePath, conf)
+
     val fs = filePath.getFileSystem(conf)
     val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
     val resultedColPruneInfo =
@@ -116,6 +132,8 @@

     val filePath = new Path(new URI(file.filePath))

+    pushDownPredicates(filePath, conf)
+
     val fs = filePath.getFileSystem(conf)
     val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
     val resultedColPruneInfo =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index 38b8ced51a14..1710abed57b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -48,7 +48,7 @@ case class OrcScan(
     // The partition values are already truncated in `FileScan.partitions`.
     // We should use `readPartitionSchema` as the partition schema here.
     OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
-      dataSchema, readDataSchema, readPartitionSchema)
+      dataSchema, readDataSchema, readPartitionSchema, pushedFilters)
   }

   override def equals(obj: Any): Boolean = obj match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
index 0330dacffa58..2f9387532c25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
@@ -56,11 +56,6 @@ case class OrcScanBuilder(

   override def pushFilters(filters: Array[Filter]): Array[Filter] = {
     if (sparkSession.sessionState.conf.orcFilterPushDown) {
-      OrcFilters.createFilter(schema, filters).foreach { f =>
-        // The pushed filters will be set in `hadoopConf`. After that, we can simply use the
-        // changed `hadoopConf` in executors.
-        OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames)
-      }
       val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
       _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray
     }
diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index bc11bb8c1d5d..0e657bfe6623 100644
--- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -81,7 +81,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {

   def convertibleFilters(
       schema: StructType,
-      dataTypeMap: Map[String, DataType],
+      dataTypeMap: Map[String, OrcPrimitiveField],
       filters: Seq[Filter]): Seq[Filter] = {
     import org.apache.spark.sql.sources._

@@ -139,7 +139,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
   /**
    * Get PredicateLeafType which is corresponding to the given DataType.
    */
-  private def getPredicateLeafType(dataType: DataType) = dataType match {
+  def getPredicateLeafType(dataType: DataType): PredicateLeaf.Type = dataType match {
     case BooleanType => PredicateLeaf.Type.BOOLEAN
     case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
     case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
@@ -179,7 +179,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    * @return the builder so far.
    */
   private def buildSearchArgument(
-      dataTypeMap: Map[String, DataType],
+      dataTypeMap: Map[String, OrcPrimitiveField],
       expression: Filter,
       builder: Builder): Builder = {
     import org.apache.spark.sql.sources._
@@ -215,11 +215,11 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    * @return the builder so far.
    */
   private def buildLeafSearchArgument(
-      dataTypeMap: Map[String, DataType],
+      dataTypeMap: Map[String, OrcPrimitiveField],
       expression: Filter,
       builder: Builder): Option[Builder] = {
     def getType(attribute: String): PredicateLeaf.Type =
-      getPredicateLeafType(dataTypeMap(attribute))
+      getPredicateLeafType(dataTypeMap(attribute).fieldType)

     import org.apache.spark.sql.sources._

@@ -228,38 +228,44 @@ private[sql] object OrcFilters extends OrcFiltersBase {
     // wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
     expression match {
       case EqualTo(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().equals(name, getType(name), castedValue).end())
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .equals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

       case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

       case LessThan(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

       case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

       case GreaterThan(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startNot()
+          .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

       case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startNot().lessThan(name, getType(name), castedValue).end())
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startNot()
+          .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

       case IsNull(name) if dataTypeMap.contains(name) =>
-        Some(builder.startAnd().isNull(name, getType(name)).end())
+        Some(builder.startAnd().isNull(dataTypeMap(name).fieldName, getType(name)).end())

       case IsNotNull(name) if dataTypeMap.contains(name) =>
-        Some(builder.startNot().isNull(name, getType(name)).end())
+        Some(builder.startNot().isNull(dataTypeMap(name).fieldName, getType(name)).end())

       case In(name, values) if dataTypeMap.contains(name) =>
-        val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
-        Some(builder.startAnd().in(name, getType(name),
+        val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType))
+        Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name),
           castedValues.map(_.asInstanceOf[AnyRef]): _*).end())

       case _ => None
diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index dfb3595be9ad..e159a0588dff 100644
--- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
 import scala.collection.JavaConverters._

 import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
+import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder

 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row}
@@ -586,8 +587,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {

         checkAnswer(sql(s"select a from $tableName"), (0 until count).map(c => Row(c - 1)))

         val actual = stripSparkFilter(sql(s"select a from $tableName where a < 0"))
-        // TODO: ORC predicate pushdown should work under case-insensitive analysis.
-        // assert(actual.count() == 1)
+        assert(actual.count() == 1)
       }
     }
@@ -606,5 +606,71 @@
       }
     }
   }
+
+  test("SPARK-32646: Case-insensitive field resolution for pushdown when reading ORC") {
+    import org.apache.spark.sql.sources._
+
+    def getOrcFilter(
+        schema: StructType,
+        filters: Seq[Filter],
+        caseSensitive: String): Option[SearchArgument] = {
+      var orcFilter: Option[SearchArgument] = None
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
+        orcFilter =
+          OrcFilters.createFilter(schema, filters)
+      }
+      orcFilter
+    }
+
+    def testFilter(
+        schema: StructType,
+        filters: Seq[Filter],
+        expected: SearchArgument): Unit = {
+      val caseSensitiveFilters = getOrcFilter(schema, filters, "true")
+      val caseInsensitiveFilters = getOrcFilter(schema, filters, "false")
+
+      assert(caseSensitiveFilters.isEmpty)
+      assert(caseInsensitiveFilters.isDefined)
+
+      assert(caseInsensitiveFilters.get.getLeaves().size() > 0)
+      assert(caseInsensitiveFilters.get.getLeaves().size() == expected.getLeaves().size())
+      (0 until expected.getLeaves().size()).foreach { index =>
+        assert(caseInsensitiveFilters.get.getLeaves().get(index) == expected.getLeaves().get(index))
+      }
+    }
+
+    val schema1 = StructType(Seq(StructField("cint", IntegerType)))
+    testFilter(schema1, Seq(GreaterThan("CINT", 1)),
+      newBuilder.startNot()
+        .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
+    testFilter(schema1, Seq(
+      And(GreaterThan("CINT", 1), EqualTo("Cint", 2))),
+      newBuilder.startAnd()
+        .startNot()
+        .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
+        .equals("cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
+        .`end`().build())
+
+    // Nested column case
+    val schema2 = StructType(Seq(StructField("a",
+      StructType(Seq(StructField("cint", IntegerType))))))
+
+    testFilter(schema2, Seq(GreaterThan("A.CINT", 1)),
+      newBuilder.startNot()
+        .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
+    testFilter(schema2, Seq(GreaterThan("a.CINT", 1)),
+      newBuilder.startNot()
+        .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
+    testFilter(schema2, Seq(GreaterThan("A.cint", 1)),
+      newBuilder.startNot()
+        .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
+    testFilter(schema2, Seq(
+      And(GreaterThan("a.CINT", 1), EqualTo("a.Cint", 2))),
+      newBuilder.startAnd()
+        .startNot()
+        .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
+        .equals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
+        .`end`().build())
+  }
 }
diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index 5273245fae45..9511fc31f4ac 100644
--- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -81,7 +81,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {

   def convertibleFilters(
       schema: StructType,
-      dataTypeMap: Map[String, DataType],
+      dataTypeMap: Map[String, OrcPrimitiveField],
       filters: Seq[Filter]): Seq[Filter] = {
     import org.apache.spark.sql.sources._

@@ -139,7 +139,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
   /**
    * Get PredicateLeafType which is corresponding to the given DataType.
    */
-  private def getPredicateLeafType(dataType: DataType) = dataType match {
+  def getPredicateLeafType(dataType: DataType): PredicateLeaf.Type = dataType match {
     case BooleanType => PredicateLeaf.Type.BOOLEAN
     case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
     case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
@@ -179,7 +179,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    * @return the builder so far.
    */
   private def buildSearchArgument(
-      dataTypeMap: Map[String, DataType],
+      dataTypeMap: Map[String, OrcPrimitiveField],
       expression: Filter,
       builder: Builder): Builder = {
     import org.apache.spark.sql.sources._
@@ -215,11 +215,11 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    * @return the builder so far.
    */
   private def buildLeafSearchArgument(
-      dataTypeMap: Map[String, DataType],
+      dataTypeMap: Map[String, OrcPrimitiveField],
       expression: Filter,
       builder: Builder): Option[Builder] = {
     def getType(attribute: String): PredicateLeaf.Type =
-      getPredicateLeafType(dataTypeMap(attribute))
+      getPredicateLeafType(dataTypeMap(attribute).fieldType)

     import org.apache.spark.sql.sources._

@@ -228,38 +228,46 @@ private[sql] object OrcFilters extends OrcFiltersBase {
     // wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
     expression match {
       case EqualTo(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().equals(name, getType(name), castedValue).end())
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .equals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

       case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

       case LessThan(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

       case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

       case GreaterThan(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startNot()
+          .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

       case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startNot().lessThan(name, getType(name), castedValue).end())
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startNot()
+          .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

       case IsNull(name) if dataTypeMap.contains(name) =>
-        Some(builder.startAnd().isNull(name, getType(name)).end())
+        Some(builder.startAnd()
+          .isNull(dataTypeMap(name).fieldName, getType(name)).end())

       case IsNotNull(name) if dataTypeMap.contains(name) =>
-        Some(builder.startNot().isNull(name, getType(name)).end())
+        Some(builder.startNot()
+          .isNull(dataTypeMap(name).fieldName, getType(name)).end())

       case In(name, values) if dataTypeMap.contains(name) =>
-        val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
-        Some(builder.startAnd().in(name, getType(name),
+        val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType))
+        Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name),
           castedValues.map(_.asInstanceOf[AnyRef]): _*).end())

       case _ => None
diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index 84cd2777da1d..afc83d7c395f 100644
--- a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
 import scala.collection.JavaConverters._

 import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument}
+import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder

 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row}
@@ -587,8 +588,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {

         checkAnswer(sql(s"select a from $tableName"), (0 until count).map(c => Row(c - 1)))

         val actual = stripSparkFilter(sql(s"select a from $tableName where a < 0"))
-        // TODO: ORC predicate pushdown should work under case-insensitive analysis.
-        // assert(actual.count() == 1)
+        assert(actual.count() == 1)
       }
     }
@@ -607,5 +607,71 @@
       }
     }
   }
+
+  test("SPARK-32646: Case-insensitive field resolution for pushdown when reading ORC") {
+    import org.apache.spark.sql.sources._
+
+    def getOrcFilter(
+        schema: StructType,
+        filters: Seq[Filter],
+        caseSensitive: String): Option[SearchArgument] = {
+      var orcFilter: Option[SearchArgument] = None
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
+        orcFilter =
+          OrcFilters.createFilter(schema, filters)
+      }
+      orcFilter
+    }
+
+    def testFilter(
+        schema: StructType,
+        filters: Seq[Filter],
+        expected: SearchArgument): Unit = {
+      val caseSensitiveFilters = getOrcFilter(schema, filters, "true")
+      val caseInsensitiveFilters = getOrcFilter(schema, filters, "false")
+
+      assert(caseSensitiveFilters.isEmpty)
+      assert(caseInsensitiveFilters.isDefined)
+
+      assert(caseInsensitiveFilters.get.getLeaves().size() > 0)
+      assert(caseInsensitiveFilters.get.getLeaves().size() == expected.getLeaves().size())
+      (0 until expected.getLeaves().size()).foreach { index =>
+        assert(caseInsensitiveFilters.get.getLeaves().get(index) == expected.getLeaves().get(index))
+      }
+    }
+
+    val schema1 = StructType(Seq(StructField("cint", IntegerType)))
+    testFilter(schema1, Seq(GreaterThan("CINT", 1)),
+      newBuilder.startNot()
+        .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
+    testFilter(schema1, Seq(
+      And(GreaterThan("CINT", 1), EqualTo("Cint", 2))),
+      newBuilder.startAnd()
+        .startNot()
+        .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
+        .equals("cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
+        .`end`().build())
+
+    // Nested column case
+    val schema2 = StructType(Seq(StructField("a",
+      StructType(Seq(StructField("cint", IntegerType))))))
+
+    testFilter(schema2, Seq(GreaterThan("A.CINT", 1)),
+      newBuilder.startNot()
+        .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
+    testFilter(schema2, Seq(GreaterThan("a.CINT", 1)),
+      newBuilder.startNot()
+        .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
+    testFilter(schema2, Seq(GreaterThan("A.cint", 1)),
+      newBuilder.startNot()
+        .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
+    testFilter(schema2, Seq(
+      And(GreaterThan("a.CINT", 1), EqualTo("a.Cint", 2))),
+      newBuilder.startAnd()
+        .startNot()
+        .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
+        .equals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
+        .`end`().build())
+  }
 }
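
Note (illustration appended after the patch, not part of it): the central idea above is that `OrcPrimitiveField` keeps the exact field name as spelled in the ORC file's schema, so a filter whose attribute was resolved case-insensitively on the Spark side can still be pushed down using the file's own spelling. The minimal, self-contained Scala sketch below mimics that lookup; `PrimitiveField`, `searchableTypeMap`, and the string-based types are hypothetical stand-ins, not Spark or ORC APIs.

// Illustrative sketch only: a case-insensitive lookup that preserves the
// file-side field name, analogous to what getSearchableTypeMap/OrcPrimitiveField
// do before the name is handed to the ORC SearchArgument builder.
object CaseInsensitivePushdownSketch {
  final case class PrimitiveField(fieldName: String, fieldType: String)

  def searchableTypeMap(
      fileFields: Seq[(String, String)],
      caseSensitive: Boolean): Map[String, PrimitiveField] = {
    val entries = fileFields.map { case (name, tpe) => name -> PrimitiveField(name, tpe) }
    if (caseSensitive) entries.toMap
    // Only the lookup key is lowered; the value keeps the original file-side name.
    else entries.map { case (name, field) => name.toLowerCase -> field }.toMap
  }

  def main(args: Array[String]): Unit = {
    val dataTypeMap = searchableTypeMap(Seq("cint" -> "int"), caseSensitive = false)
    // A predicate written against "CINT" resolves to the file's "cint" column,
    // which is the spelling the SearchArgument must use.
    println(dataTypeMap.get("CINT".toLowerCase).map(_.fieldName)) // prints: Some(cint)
  }
}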