diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 71bab624f06a8..d90804f4b6ff6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.connector.catalog import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform} +import org.apache.spark.sql.internal.SQLConf /** * Conversion helpers for working with v2 [[CatalogPlugin]]. @@ -132,4 +134,10 @@ private[sql] object CatalogV2Implicits { part } } + + private lazy val catalystSqlParser = new CatalystSqlParser(SQLConf.get) + + def parseColumnPath(name: String): Seq[String] = { + catalystSqlParser.parseMultipartIdentifier(name) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9a524defb2816..a0e991ff1fd1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2049,6 +2049,17 @@ object SQLConf { .booleanConf .createWithDefault(true) + val NESTED_PREDICATE_PUSHDOWN_ENABLED = + buildConf("spark.sql.optimizer.nestedPredicatePushdown.enabled") + .internal() + .doc("When true, Spark tries to push down predicates for nested columns and or names " + + "containing `dots` to data sources. Currently, Parquet implements both optimizations " + + "while ORC only supports predicates for names containing `dots`. The other data sources" + + "don't support this feature yet.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + val SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED = buildConf("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled") .internal() @@ -3035,6 +3046,8 @@ class SQLConf extends Serializable with Logging { def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED) + def nestedPredicatePushdownEnabled: Boolean = getConf(NESTED_PREDICATE_PUSHDOWN_ENABLED) + def serializerNestedSchemaPruningEnabled: Boolean = getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 020dd79f8f0d7..319073e4475be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.annotation.{Evolving, Stable} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. @@ -32,6 +33,10 @@ import org.apache.spark.annotation.{Evolving, Stable} sealed abstract class Filter { /** * List of columns that are referenced by this filter. 
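An aside on the two pieces added above, since the rest of this change leans on them: `parseColumnPath` turns the dotted, backquoted string form of a column reference back into its parts, and the new `spark.sql.optimizer.nestedPredicatePushdown.enabled` flag gates the whole feature. A minimal sketch of the round-trip, assuming the `MultipartIdentifierHelper` implicit from `CatalogV2Implicits` is in scope (the exact strings follow the quoting rules documented below):

import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

// "a.`b.c`" names the field `b.c` (a name that itself contains a dot) inside struct `a`.
val parts = parseColumnPath("a.`b.c`")   // Seq("a", "b.c") -- the backquoted part stays one element
val name  = parts.quoted                 // "a.`b.c`"       -- joining re-applies the backquotes

// Sources that cannot cope with nested or dotted names can be shielded by turning the flag off:
// spark.conf.set("spark.sql.optimizer.nestedPredicatePushdown.enabled", "false")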
+ * + * Note that, each element in `references` represents a column; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. + * * @since 2.1.0 */ def references: Array[String] @@ -40,12 +45,32 @@ sealed abstract class Filter { case f: Filter => f.references case _ => Array.empty } + + /** + * List of columns that are referenced by this filter. + * + * @return each element is a column name as an array of string multi-identifier + * @since 3.0.0 + */ + def v2references: Array[Array[String]] = { + this.references.map(parseColumnPath(_).toArray) + } + + /** + * If any of the references of this filter contains nested column + */ + private[sql] def containsNestedColumn: Boolean = { + this.v2references.exists(_.length > 1) + } } /** - * A filter that evaluates to `true` iff the attribute evaluates to a value + * A filter that evaluates to `true` iff the column evaluates to a value * equal to `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable @@ -58,6 +83,9 @@ case class EqualTo(attribute: String, value: Any) extends Filter { * in that it returns `true` (rather than NULL) if both inputs are NULL, and `false` * (rather than NULL) if one of the input is NULL and the other is not NULL. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.5.0 */ @Stable @@ -69,6 +97,9 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter { * A filter that evaluates to `true` iff the attribute evaluates to a value * greater than `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable @@ -80,6 +111,9 @@ case class GreaterThan(attribute: String, value: Any) extends Filter { * A filter that evaluates to `true` iff the attribute evaluates to a value * greater than or equal to `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable @@ -91,6 +125,9 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { * A filter that evaluates to `true` iff the attribute evaluates to a value * less than `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable @@ -102,6 +139,9 @@ case class LessThan(attribute: String, value: Any) extends Filter { * A filter that evaluates to `true` iff the attribute evaluates to a value * less than or equal to `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable @@ -112,6 +152,9 @@ case class LessThanOrEqual(attribute: String, value: Any) extends Filter { /** * A filter that evaluates to `true` iff the attribute evaluates to one of the values in the array. 
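The new `v2references` view is easiest to read with a concrete filter. A sketch under the quoting convention documented above; note that `containsNestedColumn` is `private[sql]`, so only internal callers such as `OrcScanBuilder` (further down in this diff) can use it to drop nested predicates:

import org.apache.spark.sql.sources.GreaterThan

val f = GreaterThan("a.`b.c`", 0)

f.references     // Array("a.`b.c`")          -- the single dotted, backquoted string form
f.v2references   // Array(Array("a", "b.c"))  -- the same reference as a multi-part identifier
// f.containsNestedColumn == true, because the reference has more than one name part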
* + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable @@ -139,6 +182,9 @@ case class In(attribute: String, values: Array[Any]) extends Filter { /** * A filter that evaluates to `true` iff the attribute evaluates to null. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable @@ -149,6 +195,9 @@ case class IsNull(attribute: String) extends Filter { /** * A filter that evaluates to `true` iff the attribute evaluates to a non-null value. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable @@ -190,6 +239,9 @@ case class Not(child: Filter) extends Filter { * A filter that evaluates to `true` iff the attribute evaluates to * a string that starts with `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.1 */ @Stable @@ -201,6 +253,9 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { * A filter that evaluates to `true` iff the attribute evaluates to * a string that ends with `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.1 */ @Stable @@ -212,6 +267,9 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { * A filter that evaluates to `true` iff the attribute evaluates to * a string that contains the string `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. 
* @since 1.3.1 */ @Stable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 1641b660a271d..faf37609ad814 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -652,10 +652,19 @@ object DataSourceStrategy { */ object PushableColumn { def unapply(e: Expression): Option[String] = { - def helper(e: Expression) = e match { - case a: Attribute => Some(a.name) + val nestedPredicatePushdownEnabled = SQLConf.get.nestedPredicatePushdownEnabled + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + def helper(e: Expression): Option[Seq[String]] = e match { + case a: Attribute => + if (nestedPredicatePushdownEnabled || !a.name.contains(".")) { + Some(Seq(a.name)) + } else { + None + } + case s: GetStructField if nestedPredicatePushdownEnabled => + helper(s.child).map(_ :+ s.childSchema(s.ordinal).name) case _ => None } - helper(e) + helper(e).map(_.quoted) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 07065018343cf..f206f59dacdc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -27,7 +27,7 @@ import scala.collection.JavaConverters.asScalaBufferConverter import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.SparkFilterApi._ import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator} +import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveComparator, PrimitiveType, Type} import org.apache.parquet.schema.OriginalType._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ @@ -49,15 +49,35 @@ class ParquetFilters( pushDownInFilterThreshold: Int, caseSensitive: Boolean) { // A map which contains parquet field name and data type, if predicate push down applies. - private val nameToParquetField : Map[String, ParquetField] = { - // Here we don't flatten the fields in the nested schema but just look up through - // root fields. Currently, accessing to nested fields does not push down filters - // and it does not support to create filters for them. - val primitiveFields = - schema.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => - f.getName -> ParquetField(f.getName, - ParquetSchemaType(f.getOriginalType, - f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)) + // + // Each key in `nameToParquetField` represents a column; `dots` are used as separators for + // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. + // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. + private val nameToParquetField : Map[String, ParquetPrimitiveField] = { + // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. 
+ // `parentFieldNames` is used to keep track of the current nested level when traversing. + def getPrimitiveFields( + fields: Seq[Type], + parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { + fields.flatMap { + case p: PrimitiveType => + Some(ParquetPrimitiveField(fieldNames = parentFieldNames :+ p.getName, + fieldType = ParquetSchemaType(p.getOriginalType, + p.getPrimitiveTypeName, p.getTypeLength, p.getDecimalMetadata))) + // Note that when g is a `Struct`, `g.getOriginalType` is `null`. + // When g is a `Map`, `g.getOriginalType` is `MAP`. + // When g is a `List`, `g.getOriginalType` is `LIST`. + case g: GroupType if g.getOriginalType == null => + getPrimitiveFields(g.getFields.asScala, parentFieldNames :+ g.getName) + // Parquet only supports push-down for primitive types; as a result, Map and List types + // are removed. + case _ => None + } + } + + val primitiveFields = getPrimitiveFields(schema.getFields.asScala).map { field => + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + (field.fieldNames.toSeq.quoted, field) } if (caseSensitive) { primitiveFields.toMap @@ -75,13 +95,13 @@ class ParquetFilters( } /** - * Holds a single field information stored in the underlying parquet file. + * Holds a single primitive field information stored in the underlying parquet file. * - * @param fieldName field name in parquet file + * @param fieldNames a field name as an array of string multi-identifier in parquet file * @param fieldType field type related info in parquet file */ - private case class ParquetField( - fieldName: String, + private case class ParquetPrimitiveField( + fieldNames: Array[String], fieldType: ParquetSchemaType) private case class ParquetSchemaType( @@ -472,13 +492,8 @@ class ParquetFilters( case _ => false } - // Parquet does not allow dots in the column name because dots are used as a column path - // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates - // with missing columns. The incorrect results could be got from Parquet when we push down - // filters for the column having dots in the names. Thus, we do not push down such filters. - // See SPARK-20364. 
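The reason the old SPARK-20364 guard can be dropped is that both sides of the lookup now use the same quoted, dotted naming. A sketch mirroring the DataSourceStrategySuite cases further down, for a hypothetical schema with a nested column a.b and a top-level column literally named c.d; it assumes the Catalyst test DSL is in scope, that PushableColumn is visible from the caller's package, and that spark.sql.optimizer.nestedPredicatePushdown.enabled is true when unapply runs:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.GetStructField
import org.apache.spark.sql.execution.datasources.PushableColumn
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val nested = GetStructField(
  Symbol("a").struct(StructType(StructField("b", IntegerType) :: Nil)), 0, None)

PushableColumn.unapply(nested)             // Some("a.b")   -- nested path joined with dots
PushableColumn.unapply(Symbol("c.d").int)  // Some("`c.d`") -- a dotted top-level name gets quoted

// ParquetFilters keys nameToParquetField with the same strings, e.g.
//   "a.b"   -> ParquetPrimitiveField(fieldNames = Array("a", "b"), ...)
//   "`c.d`" -> ParquetPrimitiveField(fieldNames = Array("c.d"), ...)
// so a source filter name either matches a properly quoted key or simply misses the map;
// canMakeFilterOn no longer needs a blanket !name.contains(".") check.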
private def canMakeFilterOn(name: String, value: Any): Boolean = { - nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value) + nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) } /** @@ -509,38 +524,38 @@ class ParquetFilters( predicate match { case sources.IsNull(name) if canMakeFilterOn(name, null) => makeEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), null)) + .map(_(nameToParquetField(name).fieldNames, null)) case sources.IsNotNull(name) if canMakeFilterOn(name, null) => makeNotEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), null)) + .map(_(nameToParquetField(name).fieldNames, null)) case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => makeEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => makeNotEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => makeEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => makeNotEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.LessThan(name, value) if canMakeFilterOn(name, value) => makeLt.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => makeLtEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => makeGt.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => makeGtEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.And(lhs, rhs) => // At here, it is not safe to just convert one side and remove the other side @@ -591,13 +606,13 @@ class ParquetFilters( && values.distinct.length <= pushDownInFilterThreshold => values.distinct.flatMap { v => makeEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), v)) + .map(_(nameToParquetField(name).fieldNames, v)) }.reduceLeftOption(FilterApi.or) case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name, prefix) => Option(prefix).map { v => - FilterApi.userDefined(binaryColumn(Array(nameToParquetField(name).fieldName)), + FilterApi.userDefined(binaryColumn(nameToParquetField(name).fieldNames), new UserDefinedPredicate[Binary] with Serializable { private val strToBinary = Binary.fromReusedByteArray(v.getBytes) private val size = strToBinary.length diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 1421ffd8b6de4..9f40f5faa2e99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -22,6 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcFilters @@ -59,8 +60,10 @@ case class OrcScanBuilder( // changed `hadoopConf` in executors. OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames) } - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap - _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray + val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap + // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. + val newFilters = filters.filter(!_.containsNestedColumn) + _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, newFilters).toArray } filters } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index 7bd3213b378ce..a775a97895cfc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -26,15 +26,61 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT class DataSourceStrategySuite extends PlanTest with SharedSparkSession { val attrInts = Seq( - 'cint.int + 'cint.int, + Symbol("c.int").int, + GetStructField('a.struct(StructType( + StructField("cstr", StringType, nullable = true) :: + StructField("cint", IntegerType, nullable = true) :: Nil)), 1, None), + GetStructField('a.struct(StructType( + StructField("c.int", IntegerType, nullable = true) :: + StructField("cstr", StringType, nullable = true) :: Nil)), 0, None), + GetStructField(Symbol("a.b").struct(StructType( + StructField("cstr1", StringType, nullable = true) :: + StructField("cstr2", StringType, nullable = true) :: + StructField("cint", IntegerType, nullable = true) :: Nil)), 2, None), + GetStructField(Symbol("a.b").struct(StructType( + StructField("c.int", IntegerType, nullable = true) :: Nil)), 0, None), + GetStructField(GetStructField('a.struct(StructType( + StructField("cstr1", StringType, nullable = true) :: + StructField("b", StructType(StructField("cint", IntegerType, nullable = true) :: + StructField("cstr2", StringType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None) ).zip(Seq( - "cint" + "cint", + "`c.int`", // single level field that contains `dot` in name + "a.cint", // two level nested field + "a.`c.int`", // two level nested field, and nested level contains `dot` + "`a.b`.cint", // two level nested field, and top level contains `dot` + "`a.b`.`c.int`", // two level nested field, and both levels contain `dot` + "a.b.cint" // 
three level nested field )) val attrStrs = Seq( - 'cstr.string + 'cstr.string, + Symbol("c.str").string, + GetStructField('a.struct(StructType( + StructField("cint", IntegerType, nullable = true) :: + StructField("cstr", StringType, nullable = true) :: Nil)), 1, None), + GetStructField('a.struct(StructType( + StructField("c.str", StringType, nullable = true) :: + StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None), + GetStructField(Symbol("a.b").struct(StructType( + StructField("cint1", IntegerType, nullable = true) :: + StructField("cint2", IntegerType, nullable = true) :: + StructField("cstr", StringType, nullable = true) :: Nil)), 2, None), + GetStructField(Symbol("a.b").struct(StructType( + StructField("c.str", StringType, nullable = true) :: Nil)), 0, None), + GetStructField(GetStructField('a.struct(StructType( + StructField("cint1", IntegerType, nullable = true) :: + StructField("b", StructType(StructField("cstr", StringType, nullable = true) :: + StructField("cint2", IntegerType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None) ).zip(Seq( - "cstr" + "cstr", + "`c.str`", // single level field that contains `dot` in name + "a.cstr", // two level nested field + "a.`c.str`", // two level nested field, and nested level contains `dot` + "`a.b`.cstr", // two level nested field, and top level contains `dot` + "`a.b`.`c.str`", // two level nested field, and both levels contain `dot` + "a.b.cstr" // three level nested field )) test("translate simple expression") { attrInts.zip(attrStrs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 4e0c1c2dbe601..d1161e33b0941 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -103,22 +103,42 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } - private def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit df: DataFrame): Unit = { - def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = { - assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted) { - df.rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted - } + /** + * Takes single level `inputDF` dataframe to generate multi-level nested + * dataframes as new test data. 
+ */ + private def withNestedDataFrame(inputDF: DataFrame) + (runTest: (DataFrame, String, Any => Any) => Unit): Unit = { + assert(inputDF.schema.fields.length == 1) + assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType]) + val df = inputDF.toDF("temp") + Seq( + ( + df.withColumnRenamed("temp", "a"), + "a", // zero nesting + (x: Any) => x), + ( + df.withColumn("a", struct(df("temp") as "b")).drop("temp"), + "a.b", // one level nesting + (x: Any) => Row(x)), + ( + df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"), + "a.b.c", // two level nesting + (x: Any) => Row(Row(x)) + ), + ( + df.withColumnRenamed("temp", "a.b"), + "`a.b`", // zero nesting with column name containing `dots` + (x: Any) => x + ), + ( + df.withColumn("a.b", struct(df("temp") as "c.d") ).drop("temp"), + "`a.b`.`c.d`", // one level nesting with column names containing `dots` + (x: Any) => Row(x) + ) + ).foreach { case (df, colName, resultFun) => + runTest(df, colName, resultFun) } - - checkFilterPredicate(df, predicate, filterClass, checkBinaryAnswer _, expected) - } - - private def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) - (implicit df: DataFrame): Unit = { - checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } private def testTimestampPushdown(data: Seq[Timestamp]): Unit = { @@ -128,36 +148,38 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared val ts3 = data(2) val ts4 = data(3) - withParquetDataFrame(data.map(i => Tuple1(i))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i))) - - checkFilterPredicate('_1 === ts1, classOf[Eq[_]], ts1) - checkFilterPredicate('_1 <=> ts1, classOf[Eq[_]], ts1) - checkFilterPredicate('_1 =!= ts1, classOf[NotEq[_]], - Seq(ts2, ts3, ts4).map(i => Row.apply(i))) - - checkFilterPredicate('_1 < ts2, classOf[Lt[_]], ts1) - checkFilterPredicate('_1 > ts1, classOf[Gt[_]], Seq(ts2, ts3, ts4).map(i => Row.apply(i))) - checkFilterPredicate('_1 <= ts1, classOf[LtEq[_]], ts1) - checkFilterPredicate('_1 >= ts4, classOf[GtEq[_]], ts4) - - checkFilterPredicate(Literal(ts1) === '_1, classOf[Eq[_]], ts1) - checkFilterPredicate(Literal(ts1) <=> '_1, classOf[Eq[_]], ts1) - checkFilterPredicate(Literal(ts2) > '_1, classOf[Lt[_]], ts1) - checkFilterPredicate(Literal(ts3) < '_1, classOf[Gt[_]], ts4) - checkFilterPredicate(Literal(ts1) >= '_1, classOf[LtEq[_]], ts1) - checkFilterPredicate(Literal(ts4) <= '_1, classOf[GtEq[_]], ts4) - - checkFilterPredicate(!('_1 < ts4), classOf[GtEq[_]], ts4) - checkFilterPredicate('_1 < ts2 || '_1 > ts3, classOf[Operators.Or], Seq(Row(ts1), Row(ts4))) - } - } - - private def testDecimalPushDown(data: DataFrame)(f: DataFrame => Unit): Unit = { - withTempPath { file => - data.write.parquet(file.getCanonicalPath) - readParquetFile(file.toString)(f) + import testImplicits._ + withNestedDataFrame(data.map(i => Tuple1(i)).toDF()) { case (inputDF, colName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val tsAttr = df(colName).expr + assert(df(colName).expr.dataType === TimestampType) + + checkFilterPredicate(tsAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]], + data.map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(tsAttr === ts1, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr <=> ts1, classOf[Eq[_]], 
resultFun(ts1)) + checkFilterPredicate(tsAttr =!= ts1, classOf[NotEq[_]], + Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(tsAttr < ts2, classOf[Lt[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr > ts1, classOf[Gt[_]], + Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i)))) + checkFilterPredicate(tsAttr <= ts1, classOf[LtEq[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr >= ts4, classOf[GtEq[_]], resultFun(ts4)) + + checkFilterPredicate(Literal(ts1) === tsAttr, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts1) <=> tsAttr, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts2) > tsAttr, classOf[Lt[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts3) < tsAttr, classOf[Gt[_]], resultFun(ts4)) + checkFilterPredicate(Literal(ts1) >= tsAttr, classOf[LtEq[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts4) <= tsAttr, classOf[GtEq[_]], resultFun(ts4)) + + checkFilterPredicate(!(tsAttr < ts4), classOf[GtEq[_]], resultFun(ts4)) + checkFilterPredicate(tsAttr < ts2 || tsAttr > ts3, classOf[Operators.Or], + Seq(Row(resultFun(ts1)), Row(resultFun(ts4)))) + } } } @@ -187,201 +209,273 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - boolean") { - withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) - - checkFilterPredicate('_1 === true, classOf[Eq[_]], true) - checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true) - checkFilterPredicate('_1 =!= true, classOf[NotEq[_]], false) + val data = (true :: false :: Nil).map(b => Tuple1.apply(Option(b))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val booleanAttr = df(colName).expr + assert(df(colName).expr.dataType === BooleanType) + + checkFilterPredicate(booleanAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(booleanAttr.isNotNull, classOf[NotEq[_]], + Seq(Row(resultFun(true)), Row(resultFun(false)))) + + checkFilterPredicate(booleanAttr === true, classOf[Eq[_]], resultFun(true)) + checkFilterPredicate(booleanAttr <=> true, classOf[Eq[_]], resultFun(true)) + checkFilterPredicate(booleanAttr =!= true, classOf[NotEq[_]], resultFun(false)) + } } } test("filter pushdown - tinyint") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toByte)))) { implicit df => - assert(df.schema.head.dataType === ByteType) - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1.toByte, classOf[Eq[_]], 1) - checkFilterPredicate('_1 <=> 1.toByte, classOf[Eq[_]], 1) - checkFilterPredicate('_1 =!= 1.toByte, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 < 2.toByte, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3.toByte, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1.toByte, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4.toByte, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1.toByte) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1.toByte) <=> '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2.toByte) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3.toByte) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1.toByte) 
>= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4.toByte) <= '_1, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('_1 < 4.toByte), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 < 2.toByte || '_1 > 3.toByte, - classOf[Operators.Or], Seq(Row(1), Row(4))) + val data = (1 to 4).map(i => Tuple1(Option(i.toByte))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val tinyIntAttr = df(colName).expr + assert(df(colName).expr.dataType === ByteType) + + checkFilterPredicate(tinyIntAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(tinyIntAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(tinyIntAttr === 1.toByte, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(tinyIntAttr <=> 1.toByte, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(tinyIntAttr =!= 1.toByte, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(tinyIntAttr < 2.toByte, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(tinyIntAttr > 3.toByte, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(tinyIntAttr <= 1.toByte, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(tinyIntAttr >= 4.toByte, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1.toByte) === tinyIntAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1.toByte) <=> tinyIntAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2.toByte) > tinyIntAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3.toByte) < tinyIntAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1.toByte) >= tinyIntAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4.toByte) <= tinyIntAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(tinyIntAttr < 4.toByte), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(tinyIntAttr < 2.toByte || tinyIntAttr > 3.toByte, + classOf[Operators.Or], Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } test("filter pushdown - smallint") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => - assert(df.schema.head.dataType === ShortType) - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1.toShort, classOf[Eq[_]], 1) - checkFilterPredicate('_1 <=> 1.toShort, classOf[Eq[_]], 1) - checkFilterPredicate('_1 =!= 1.toShort, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 < 2.toShort, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3.toShort, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1.toShort, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4.toShort, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1.toShort) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1.toShort) <=> '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2.toShort) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3.toShort) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1.toShort) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4.toShort) <= '_1, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('_1 < 4.toShort), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 < 2.toShort || '_1 > 3.toShort, - classOf[Operators.Or], Seq(Row(1), Row(4))) + val data = (1 to 4).map(i => Tuple1(Option(i.toShort))) + import testImplicits._ + 
withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val smallIntAttr = df(colName).expr + assert(df(colName).expr.dataType === ShortType) + + checkFilterPredicate(smallIntAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(smallIntAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(smallIntAttr === 1.toShort, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(smallIntAttr <=> 1.toShort, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(smallIntAttr =!= 1.toShort, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(smallIntAttr < 2.toShort, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(smallIntAttr > 3.toShort, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(smallIntAttr <= 1.toShort, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(smallIntAttr >= 4.toShort, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1.toShort) === smallIntAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1.toShort) <=> smallIntAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2.toShort) > smallIntAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3.toShort) < smallIntAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1.toShort) >= smallIntAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4.toShort) <= smallIntAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(smallIntAttr < 4.toShort), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(smallIntAttr < 2.toShort || smallIntAttr > 3.toShort, + classOf[Operators.Or], Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } test("filter pushdown - integer") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + val data = (1 to 4).map(i => Tuple1(Option(i))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val intAttr = df(colName).expr + assert(df(colName).expr.dataType === IntegerType) + + checkFilterPredicate(intAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(intAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(intAttr === 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(intAttr <=> 1, classOf[Eq[_]], 
resultFun(1)) + checkFilterPredicate(intAttr =!= 1, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(intAttr < 2, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(intAttr > 3, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(intAttr <= 1, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(intAttr >= 4, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1) === intAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1) <=> intAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2) > intAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3) < intAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1) >= intAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4) <= intAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(intAttr < 4), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(intAttr < 2 || intAttr > 3, classOf[Operators.Or], + Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } test("filter pushdown - long") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + val data = (1 to 4).map(i => Tuple1(Option(i.toLong))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val longAttr = df(colName).expr + assert(df(colName).expr.dataType === LongType) + + checkFilterPredicate(longAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(longAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(longAttr === 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(longAttr <=> 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(longAttr =!= 1, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(longAttr < 2, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(longAttr > 3, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(longAttr <= 1, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(longAttr >= 4, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1) === longAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1) <=> longAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2) > longAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3) < longAttr, classOf[Gt[_]], resultFun(4)) + 
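Stepping back from the individual assertions for a moment: every rewritten test in this file follows the same pattern, and resultFun is what keeps the expected rows aligned with the nesting depth. A sketch of what withNestedDataFrame (defined near the top of this suite) hands to each test body, written as it would appear inside this suite and using a bigint input column for illustration:

import testImplicits._

withNestedDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong))).toDF()) {
  case (inputDF, colName, resultFun) =>
    // colName          inputDF schema                       resultFun(1L)
    // "a"              a: bigint                            1L
    // "a.b"            a: struct<b: bigint>                 Row(1L)
    // "a.b.c"          a: struct<b: struct<c: bigint>>      Row(Row(1L))
    // "`a.b`"          `a.b`: bigint                        1L
    // "`a.b`.`c.d`"    `a.b`: struct<`c.d`: bigint>         Row(1L)
    //
    // df(colName).expr therefore resolves to the leaf attribute (hence the dataType asserts),
    // and wrapping expected values in resultFun reproduces the Row shape that collect()
    // returns for struct columns at that nesting depth.
    assert(inputDF.schema.fields.length == 1)   // each generated DataFrame has one top-level column
}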
checkFilterPredicate(Literal(1) >= longAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4) <= longAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(longAttr < 4), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(longAttr < 2 || longAttr > 3, classOf[Operators.Or], + Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } test("filter pushdown - float") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + val data = (1 to 4).map(i => Tuple1(Option(i.toFloat))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val floatAttr = df(colName).expr + assert(df(colName).expr.dataType === FloatType) + + checkFilterPredicate(floatAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(floatAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(floatAttr === 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(floatAttr <=> 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(floatAttr =!= 1, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(floatAttr < 2, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(floatAttr > 3, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(floatAttr <= 1, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(floatAttr >= 4, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1) === floatAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1) <=> floatAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2) > floatAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3) < floatAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1) >= floatAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4) <= floatAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(floatAttr < 4), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(floatAttr < 2 || floatAttr > 3, classOf[Operators.Or], + Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } test("filter pushdown - double") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1, 
classOf[Eq[_]], 1) - checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + val data = (1 to 4).map(i => Tuple1(Option(i.toDouble))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val doubleAttr = df(colName).expr + assert(df(colName).expr.dataType === DoubleType) + + checkFilterPredicate(doubleAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(doubleAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(doubleAttr === 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(doubleAttr <=> 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(doubleAttr =!= 1, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(doubleAttr < 2, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(doubleAttr > 3, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(doubleAttr <= 1, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(doubleAttr >= 4, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1) === doubleAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1) <=> doubleAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2) > doubleAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3) < doubleAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1) >= doubleAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4) <= doubleAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(doubleAttr < 4), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(doubleAttr < 2 || doubleAttr > 3, classOf[Operators.Or], + Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } test("filter pushdown - string") { - withParquetDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") - checkFilterPredicate('_1 <=> "1", classOf[Eq[_]], "1") - checkFilterPredicate( - '_1 =!= "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") - checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") - checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") - - checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") - checkFilterPredicate(Literal("1") <=> '_1, classOf[Eq[_]], "1") - checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") - checkFilterPredicate(Literal("3") < '_1, 
classOf[Gt[_]], "4") - checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") - checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") - - checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") - checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + val data = (1 to 4).map(i => Tuple1(Option(i.toString))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val stringAttr = df(colName).expr + assert(df(colName).expr.dataType === StringType) + + checkFilterPredicate(stringAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(stringAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i.toString)))) + + checkFilterPredicate(stringAttr === "1", classOf[Eq[_]], resultFun("1")) + checkFilterPredicate(stringAttr <=> "1", classOf[Eq[_]], resultFun("1")) + checkFilterPredicate(stringAttr =!= "1", classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i.toString)))) + + checkFilterPredicate(stringAttr < "2", classOf[Lt[_]], resultFun("1")) + checkFilterPredicate(stringAttr > "3", classOf[Gt[_]], resultFun("4")) + checkFilterPredicate(stringAttr <= "1", classOf[LtEq[_]], resultFun("1")) + checkFilterPredicate(stringAttr >= "4", classOf[GtEq[_]], resultFun("4")) + + checkFilterPredicate(Literal("1") === stringAttr, classOf[Eq[_]], resultFun("1")) + checkFilterPredicate(Literal("1") <=> stringAttr, classOf[Eq[_]], resultFun("1")) + checkFilterPredicate(Literal("2") > stringAttr, classOf[Lt[_]], resultFun("1")) + checkFilterPredicate(Literal("3") < stringAttr, classOf[Gt[_]], resultFun("4")) + checkFilterPredicate(Literal("1") >= stringAttr, classOf[LtEq[_]], resultFun("1")) + checkFilterPredicate(Literal("4") <= stringAttr, classOf[GtEq[_]], resultFun("4")) + + checkFilterPredicate(!(stringAttr < "4"), classOf[GtEq[_]], resultFun("4")) + checkFilterPredicate(stringAttr < "2" || stringAttr > "3", classOf[Operators.Or], + Seq(Row(resultFun("1")), Row(resultFun("4")))) + } } } @@ -390,32 +484,39 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) } - withParquetDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => - checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate('_1 <=> 1.b, classOf[Eq[_]], 1.b) - - checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkBinaryFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) - - checkBinaryFilterPredicate( - '_1 =!= 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) - - checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - - checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate(Literal(1.b) <=> '_1, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) - - checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate( - '_1 < 2.b || 
'_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) + val data = (1 to 4).map(i => Tuple1(Option(i.b))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val binaryAttr: Expression = df(colName).expr + assert(df(colName).expr.dataType === BinaryType) + + checkFilterPredicate(binaryAttr === 1.b, classOf[Eq[_]], resultFun(1.b)) + checkFilterPredicate(binaryAttr <=> 1.b, classOf[Eq[_]], resultFun(1.b)) + + checkFilterPredicate(binaryAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(binaryAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i.b)))) + + checkFilterPredicate(binaryAttr =!= 1.b, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i.b)))) + + checkFilterPredicate(binaryAttr < 2.b, classOf[Lt[_]], resultFun(1.b)) + checkFilterPredicate(binaryAttr > 3.b, classOf[Gt[_]], resultFun(4.b)) + checkFilterPredicate(binaryAttr <= 1.b, classOf[LtEq[_]], resultFun(1.b)) + checkFilterPredicate(binaryAttr >= 4.b, classOf[GtEq[_]], resultFun(4.b)) + + checkFilterPredicate(Literal(1.b) === binaryAttr, classOf[Eq[_]], resultFun(1.b)) + checkFilterPredicate(Literal(1.b) <=> binaryAttr, classOf[Eq[_]], resultFun(1.b)) + checkFilterPredicate(Literal(2.b) > binaryAttr, classOf[Lt[_]], resultFun(1.b)) + checkFilterPredicate(Literal(3.b) < binaryAttr, classOf[Gt[_]], resultFun(4.b)) + checkFilterPredicate(Literal(1.b) >= binaryAttr, classOf[LtEq[_]], resultFun(1.b)) + checkFilterPredicate(Literal(4.b) <= binaryAttr, classOf[GtEq[_]], resultFun(4.b)) + + checkFilterPredicate(!(binaryAttr < 4.b), classOf[GtEq[_]], resultFun(4.b)) + checkFilterPredicate(binaryAttr < 2.b || binaryAttr > 3.b, classOf[Operators.Or], + Seq(Row(resultFun(1.b)), Row(resultFun(4.b)))) + } } } @@ -424,40 +525,53 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared def date: Date = Date.valueOf(s) } - val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21") - - withParquetDataFrame(data.map(i => Tuple1(i.date))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i.date))) - - checkFilterPredicate('_1 === "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) - checkFilterPredicate('_1 <=> "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) - checkFilterPredicate('_1 =!= "2018-03-18".date, classOf[NotEq[_]], - Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(i.date))) - - checkFilterPredicate('_1 < "2018-03-19".date, classOf[Lt[_]], "2018-03-18".date) - checkFilterPredicate('_1 > "2018-03-20".date, classOf[Gt[_]], "2018-03-21".date) - checkFilterPredicate('_1 <= "2018-03-18".date, classOf[LtEq[_]], "2018-03-18".date) - checkFilterPredicate('_1 >= "2018-03-21".date, classOf[GtEq[_]], "2018-03-21".date) - - checkFilterPredicate( - Literal("2018-03-18".date) === '_1, classOf[Eq[_]], "2018-03-18".date) - checkFilterPredicate( - Literal("2018-03-18".date) <=> '_1, classOf[Eq[_]], "2018-03-18".date) - checkFilterPredicate( - Literal("2018-03-19".date) > '_1, classOf[Lt[_]], "2018-03-18".date) - checkFilterPredicate( - Literal("2018-03-20".date) < '_1, classOf[Gt[_]], "2018-03-21".date) - checkFilterPredicate( - Literal("2018-03-18".date) >= '_1, classOf[LtEq[_]], "2018-03-18".date) - checkFilterPredicate( - Literal("2018-03-21".date) <= '_1, classOf[GtEq[_]], "2018-03-21".date) - - checkFilterPredicate(!('_1 < 
"2018-03-21".date), classOf[GtEq[_]], "2018-03-21".date) - checkFilterPredicate( - '_1 < "2018-03-19".date || '_1 > "2018-03-20".date, - classOf[Operators.Or], - Seq(Row("2018-03-18".date), Row("2018-03-21".date))) + val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21").map(_.date) + import testImplicits._ + withNestedDataFrame(data.map(i => Tuple1(i)).toDF()) { case (inputDF, colName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val dateAttr: Expression = df(colName).expr + assert(df(colName).expr.dataType === DateType) + + checkFilterPredicate(dateAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(dateAttr.isNotNull, classOf[NotEq[_]], + data.map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(dateAttr === "2018-03-18".date, classOf[Eq[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(dateAttr <=> "2018-03-18".date, classOf[Eq[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(dateAttr =!= "2018-03-18".date, classOf[NotEq[_]], + Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(resultFun(i.date)))) + + checkFilterPredicate(dateAttr < "2018-03-19".date, classOf[Lt[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(dateAttr > "2018-03-20".date, classOf[Gt[_]], + resultFun("2018-03-21".date)) + checkFilterPredicate(dateAttr <= "2018-03-18".date, classOf[LtEq[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(dateAttr >= "2018-03-21".date, classOf[GtEq[_]], + resultFun("2018-03-21".date)) + + checkFilterPredicate(Literal("2018-03-18".date) === dateAttr, classOf[Eq[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(Literal("2018-03-18".date) <=> dateAttr, classOf[Eq[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(Literal("2018-03-19".date) > dateAttr, classOf[Lt[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(Literal("2018-03-20".date) < dateAttr, classOf[Gt[_]], + resultFun("2018-03-21".date)) + checkFilterPredicate(Literal("2018-03-18".date) >= dateAttr, classOf[LtEq[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(Literal("2018-03-21".date) <= dateAttr, classOf[GtEq[_]], + resultFun("2018-03-21".date)) + + checkFilterPredicate(!(dateAttr < "2018-03-21".date), classOf[GtEq[_]], + resultFun("2018-03-21".date)) + checkFilterPredicate( + dateAttr < "2018-03-19".date || dateAttr > "2018-03-20".date, + classOf[Operators.Or], + Seq(Row(resultFun("2018-03-18".date)), Row(resultFun("2018-03-21".date)))) + } } } @@ -485,7 +599,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> ParquetOutputTimestampType.INT96.toString) { - withParquetDataFrame(millisData.map(i => Tuple1(i))) { implicit df => + import testImplicits._ + withParquetDataFrame(millisData.map(i => Tuple1(i)).toDF()) { implicit df => val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema) assertResult(None) { createParquetFilters(schema).createFilter(sources.IsNull("_1")) @@ -502,33 +617,39 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared (false, DecimalType.MAX_PRECISION) // binaryWriterUsingUnscaledBytes ).foreach { case (legacyFormat, precision) => withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> legacyFormat.toString) { - val schema = StructType.fromDDL(s"a decimal($precision, 2)") val rdd = spark.sparkContext.parallelize((1 to 4).map(i => Row(new 
java.math.BigDecimal(i)))) - val dataFrame = spark.createDataFrame(rdd, schema) - testDecimalPushDown(dataFrame) { implicit df => - assert(df.schema === schema) - checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('a === 1, classOf[Eq[_]], 1) - checkFilterPredicate('a <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('a =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('a < 2, classOf[Lt[_]], 1) - checkFilterPredicate('a > 3, classOf[Gt[_]], 4) - checkFilterPredicate('a <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('a >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === 'a, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1) <=> 'a, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > 'a, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < 'a, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= 'a, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= 'a, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('a < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('a < 2 || 'a > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + val dataFrame = spark.createDataFrame(rdd, StructType.fromDDL(s"a decimal($precision, 2)")) + withNestedDataFrame(dataFrame) { case (inputDF, colName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val decimalAttr: Expression = df(colName).expr + assert(df(colName).expr.dataType === DecimalType(precision, 2)) + + checkFilterPredicate(decimalAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(decimalAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(decimalAttr === 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(decimalAttr <=> 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(decimalAttr =!= 1, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(decimalAttr < 2, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(decimalAttr > 3, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(decimalAttr <= 1, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(decimalAttr >= 4, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1) === decimalAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1) <=> decimalAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2) > decimalAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3) < decimalAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1) >= decimalAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4) <= decimalAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(decimalAttr < 4), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(decimalAttr < 2 || decimalAttr > 3, classOf[Operators.Or], + Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } } @@ -1042,7 +1163,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("SPARK-16371 Do not push down filters when inner name and outer name are the same") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Tuple1(i)))) { implicit df => + import testImplicits._ + withParquetDataFrame((1 to 4).map(i => Tuple1(Tuple1(i))).toDF()) { implicit df => // Here the schema becomes as below: // // root @@ -1107,7 +1229,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } } - test("SPARK-20364: Disable Parquet 
predicate pushdown for fields having dots in the names") { + test("SPARK-31026: Parquet predicate pushdown for fields having dots in the names") { import testImplicits._ Seq(true, false).foreach { vectorized => @@ -1120,6 +1242,28 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared assert(readBack.count() == 1) } } + + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString, + // Makes sure disabling 'spark.sql.parquet.recordFilter' still enables + // row group level filtering. + SQLConf.PARQUET_RECORD_FILTER_ENABLED.key -> "false", + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + + withTempPath { path => + val data = (1 to 1024) + data.toDF("col.dots").coalesce(1) + .write.option("parquet.block.size", 512) + .parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath).filter("`col.dots` == 500") + // Here, we strip the Spark side filter and check the actual results from Parquet. + val actual = stripSparkFilter(df).collect().length + // Since those are filtered at row group level, the result count should be less + // than the total length but should not be a single record. + // Note that, if record level filtering is enabled, it should be a single record. + // If no filter is pushed down to Parquet, it should be the total length of data. + assert(actual > 1 && actual < data.length) + } + } } } @@ -1162,7 +1306,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - StringStartsWith") { - withParquetDataFrame((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df => + withParquetDataFrame { + import testImplicits._ + (1 to 4).map(i => Tuple1(i + "str" + i)).toDF() + } { implicit df => checkFilterPredicate( '_1.startsWith("").asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], @@ -1208,7 +1355,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } // SPARK-28371: make sure filter is null-safe. - withParquetDataFrame(Seq(Tuple1[String](null))) { implicit df => + withParquetDataFrame { + import testImplicits._ + Seq(Tuple1[String](null)).toDF() + } { implicit df => checkFilterPredicate( '_1.startsWith("blah").asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 7f85fd2a1629a..497b823868450 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -82,7 +82,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession * Writes `data` to a Parquet file, reads it back and check file contents. 
*/ protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = { - withParquetDataFrame(data)(r => checkAnswer(r, data.map(Row.fromTuple))) + withParquetDataFrame(data.toDF())(r => checkAnswer(r, data.map(Row.fromTuple))) } test("basic data types (without binary)") { @@ -94,7 +94,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession test("raw binary") { val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) - withParquetDataFrame(data) { df => + withParquetDataFrame(data.toDF()) { df => assertResult(data.map(_._1.mkString(",")).sorted) { df.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted } @@ -197,7 +197,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession testStandardAndLegacyModes("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) - withParquetDataFrame(data) { df => + withParquetDataFrame(data.toDF()) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(struct) => Row(Row(struct.productIterator.toSeq: _*)) @@ -214,7 +214,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession ) ) } - withParquetDataFrame(data) { df => + withParquetDataFrame(data.toDF()) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(array) => Row(array.map(struct => Row(struct.productIterator.toSeq: _*))) @@ -233,7 +233,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession ) ) } - withParquetDataFrame(data) { df => + withParquetDataFrame(data.toDF()) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(array) => Row(array.map { case Tuple1(Tuple1(str)) => Row(Row(str))}) @@ -243,7 +243,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession testStandardAndLegacyModes("nested struct with array of array as field") { val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) - withParquetDataFrame(data) { df => + withParquetDataFrame(data.toDF()) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(struct) => Row(Row(struct.productIterator.toSeq: _*)) @@ -260,7 +260,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession ) ) } - withParquetDataFrame(data) { df => + withParquetDataFrame(data.toDF()) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(m) => Row(m.map { case (k, v) => Row(k.productIterator.toSeq: _*) -> v }) @@ -277,7 +277,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession ) ) } - withParquetDataFrame(data) { df => + withParquetDataFrame(data.toDF()) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(m) => Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) @@ -293,7 +293,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession null.asInstanceOf[java.lang.Float], null.asInstanceOf[java.lang.Double]) - withParquetDataFrame(allNulls :: Nil) { df => + withParquetDataFrame((allNulls :: Nil).toDF()) { df => val rows = df.collect() assert(rows.length === 1) assert(rows.head === Row(Seq.fill(5)(null): _*)) @@ -306,7 +306,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession None.asInstanceOf[Option[Long]], None.asInstanceOf[Option[String]]) - withParquetDataFrame(allNones :: Nil) { df => + withParquetDataFrame((allNones :: Nil).toDF()) { df => val rows = df.collect() assert(rows.length === 
1) assert(rows.head === Row(Seq.fill(3)(null): _*)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 828ba6aee026b..f2dbc536ac566 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -63,12 +63,16 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { (f: String => Unit): Unit = withDataSourceFile(data)(f) /** - * Writes `data` to a Parquet file and reads it back as a [[DataFrame]], + * Writes `df` dataframe to a Parquet file and reads it back as a [[DataFrame]], * which is then passed to `f`. The Parquet file will be deleted after `f` returns. */ - protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] - (data: Seq[T], testVectorized: Boolean = true) - (f: DataFrame => Unit): Unit = withDataSourceDataFrame(data, testVectorized)(f) + protected def withParquetDataFrame(df: DataFrame, testVectorized: Boolean = true) + (f: DataFrame => Unit): Unit = { + withTempPath { file => + df.write.format(dataSourceName).save(file.getCanonicalPath) + readFile(file.getCanonicalPath, testVectorized)(f) + } + } /** * Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala index 1cb7a2156c3d3..33b2db57d9f0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala @@ -24,66 +24,143 @@ import org.apache.spark.SparkFunSuite */ class FiltersSuite extends SparkFunSuite { - test("EqualTo references") { - assert(EqualTo("a", "1").references.toSeq == Seq("a")) - assert(EqualTo("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) + private def withFieldNames(f: (String, Array[String]) => Unit): Unit = { + Seq(("a", Array("a")), + ("a.b", Array("a", "b")), + ("`a.b`.c", Array("a.b", "c")), + ("`a.b`.`c.d`.`e.f`", Array("a.b", "c.d", "e.f")) + ).foreach { case (name, fieldNames) => + f(name, fieldNames) + } } - test("EqualNullSafe references") { - assert(EqualNullSafe("a", "1").references.toSeq == Seq("a")) - assert(EqualNullSafe("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) - } + test("EqualTo references") { withFieldNames { (name, fieldNames) => + assert(EqualTo(name, "1").references.toSeq == Seq(name)) + assert(EqualTo(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - test("GreaterThan references") { - assert(GreaterThan("a", "1").references.toSeq == Seq("a")) - assert(GreaterThan("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) - } + assert(EqualTo(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) + assert(EqualTo("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) - test("GreaterThanOrEqual references") { - assert(GreaterThanOrEqual("a", "1").references.toSeq == Seq("a")) - assert(GreaterThanOrEqual("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) - } + assert(EqualTo(name, EqualTo("b", "2")).v2references.toSeq.map(_.toSeq) + == Seq(fieldNames.toSeq, Seq("b"))) + assert(EqualTo("b", EqualTo(name, "2")).v2references.toSeq.map(_.toSeq) + == Seq(Seq("b"), fieldNames.toSeq)) + }} - test("LessThan references") { - 
assert(LessThan("a", "1").references.toSeq == Seq("a")) - assert(LessThan("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) - } + test("EqualNullSafe references") { withFieldNames { (name, fieldNames) => + assert(EqualNullSafe(name, "1").references.toSeq == Seq(name)) + assert(EqualNullSafe(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - test("LessThanOrEqual references") { - assert(LessThanOrEqual("a", "1").references.toSeq == Seq("a")) - assert(LessThanOrEqual("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) - } + assert(EqualNullSafe(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) + assert(EqualNullSafe("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) - test("In references") { - assert(In("a", Array("1")).references.toSeq == Seq("a")) - assert(In("a", Array("1", EqualTo("b", "2"))).references.toSeq == Seq("a", "b")) - } + assert(EqualNullSafe(name, EqualTo("b", "2")).v2references.toSeq.map(_.toSeq) + == Seq(fieldNames.toSeq, Seq("b"))) + assert(EqualNullSafe("b", EqualTo(name, "2")).v2references.toSeq.map(_.toSeq) + == Seq(Seq("b"), fieldNames.toSeq)) + }} - test("IsNull references") { - assert(IsNull("a").references.toSeq == Seq("a")) - } + test("GreaterThan references") { withFieldNames { (name, fieldNames) => + assert(GreaterThan(name, "1").references.toSeq == Seq(name)) + assert(GreaterThan(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - test("IsNotNull references") { - assert(IsNotNull("a").references.toSeq == Seq("a")) - } + assert(GreaterThan(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) + assert(GreaterThan("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) - test("And references") { - assert(And(EqualTo("a", "1"), EqualTo("b", "1")).references.toSeq == Seq("a", "b")) - } + assert(GreaterThan(name, EqualTo("b", "2")).v2references.toSeq.map(_.toSeq) + == Seq(fieldNames.toSeq, Seq("b"))) + assert(GreaterThan("b", EqualTo(name, "2")).v2references.toSeq.map(_.toSeq) + == Seq(Seq("b"), fieldNames.toSeq)) + }} - test("Or references") { - assert(Or(EqualTo("a", "1"), EqualTo("b", "1")).references.toSeq == Seq("a", "b")) - } + test("GreaterThanOrEqual references") { withFieldNames { (name, fieldNames) => + assert(GreaterThanOrEqual(name, "1").references.toSeq == Seq(name)) + assert(GreaterThanOrEqual(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - test("StringStartsWith references") { - assert(StringStartsWith("a", "str").references.toSeq == Seq("a")) - } + assert(GreaterThanOrEqual(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) + assert(GreaterThanOrEqual("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) - test("StringEndsWith references") { - assert(StringEndsWith("a", "str").references.toSeq == Seq("a")) - } + assert(GreaterThanOrEqual(name, EqualTo("b", "2")).v2references.toSeq.map(_.toSeq) + == Seq(fieldNames.toSeq, Seq("b"))) + assert(GreaterThanOrEqual("b", EqualTo(name, "2")).v2references.toSeq.map(_.toSeq) + == Seq(Seq("b"), fieldNames.toSeq)) + }} - test("StringContains references") { - assert(StringContains("a", "str").references.toSeq == Seq("a")) - } + test("LessThan references") { withFieldNames { (name, fieldNames) => + assert(LessThan(name, "1").references.toSeq == Seq(name)) + assert(LessThan(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + + assert(LessThan("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) + }} + + test("LessThanOrEqual references") { withFieldNames { (name, 
fieldNames) =>
+ assert(LessThanOrEqual(name, "1").references.toSeq == Seq(name))
+ assert(LessThanOrEqual(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq))
+
+ assert(LessThanOrEqual(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b"))
+ assert(LessThanOrEqual("b", EqualTo(name, "2")).references.toSeq == Seq("b", name))
+
+ assert(LessThanOrEqual(name, EqualTo("b", "2")).v2references.toSeq.map(_.toSeq)
+ == Seq(fieldNames.toSeq, Seq("b")))
+ assert(LessThanOrEqual("b", EqualTo(name, "2")).v2references.toSeq.map(_.toSeq)
+ == Seq(Seq("b"), fieldNames.toSeq))
+ }}
+
+ test("In references") { withFieldNames { (name, fieldNames) =>
+ assert(In(name, Array("1")).references.toSeq == Seq(name))
+ assert(In(name, Array("1")).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq))
+
+ assert(In(name, Array("1", EqualTo("b", "2"))).references.toSeq == Seq(name, "b"))
+ assert(In("b", Array("1", EqualTo(name, "2"))).references.toSeq == Seq("b", name))
+
+ assert(In(name, Array("1", EqualTo("b", "2"))).v2references.toSeq.map(_.toSeq)
+ == Seq(fieldNames.toSeq, Seq("b")))
+ assert(In("b", Array("1", EqualTo(name, "2"))).v2references.toSeq.map(_.toSeq)
+ == Seq(Seq("b"), fieldNames.toSeq))
+ }}
+
+ test("IsNull references") { withFieldNames { (name, fieldNames) =>
+ assert(IsNull(name).references.toSeq == Seq(name))
+ assert(IsNull(name).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq))
+ }}
+
+ test("IsNotNull references") { withFieldNames { (name, fieldNames) =>
+ assert(IsNotNull(name).references.toSeq == Seq(name))
+ assert(IsNotNull(name).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq))
+ }}
+
+ test("And references") { withFieldNames { (name, fieldNames) =>
+ assert(And(EqualTo(name, "1"), EqualTo("b", "1")).references.toSeq == Seq(name, "b"))
+ assert(And(EqualTo("b", "1"), EqualTo(name, "1")).references.toSeq == Seq("b", name))
+
+ assert(And(EqualTo(name, "1"), EqualTo("b", "1")).v2references.toSeq.map(_.toSeq) ==
+ Seq(fieldNames.toSeq, Seq("b")))
+ assert(And(EqualTo("b", "1"), EqualTo(name, "1")).v2references.toSeq.map(_.toSeq) ==
+ Seq(Seq("b"), fieldNames.toSeq))
+ }}
+
+ test("Or references") { withFieldNames { (name, fieldNames) =>
+ assert(Or(EqualTo(name, "1"), EqualTo("b", "1")).references.toSeq == Seq(name, "b"))
+ assert(Or(EqualTo("b", "1"), EqualTo(name, "1")).references.toSeq == Seq("b", name))
+
+ assert(Or(EqualTo(name, "1"), EqualTo("b", "1")).v2references.toSeq.map(_.toSeq) ==
+ Seq(fieldNames.toSeq, Seq("b")))
+ assert(Or(EqualTo("b", "1"), EqualTo(name, "1")).v2references.toSeq.map(_.toSeq) ==
+ Seq(Seq("b"), fieldNames.toSeq))
+ }}
+
+ test("StringStartsWith references") { withFieldNames { (name, fieldNames) =>
+ assert(StringStartsWith(name, "str").references.toSeq == Seq(name))
+ assert(StringStartsWith(name, "str").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq))
+ }}
+
+ test("StringEndsWith references") { withFieldNames { (name, fieldNames) =>
+ assert(StringEndsWith(name, "str").references.toSeq == Seq(name))
+ assert(StringEndsWith(name, "str").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq))
+ }}
+
+ test("StringContains references") { withFieldNames { (name, fieldNames) =>
+ assert(StringContains(name, "str").references.toSeq == Seq(name))
+ assert(StringContains(name, "str").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq))
+ }}
 }
diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index b9cbc484e1fc1..f5abd30854e00 100644 --- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -65,9 +65,11 @@ private[sql] object OrcFilters extends OrcFiltersBase { * Create ORC filter as a SearchArgument instance. */ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap // Combines all convertible filters using `And` to produce a single conjunction - val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) + // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. + val newFilters = filters.filter(!_.containsNestedColumn) + val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, newFilters)) conjunctionOptional.map { conjunction => // Then tries to build a single ORC `SearchArgument` for the conjunction predicate. // The input predicate is fully convertible. There should not be any empty result in the @@ -222,48 +224,39 @@ private[sql] object OrcFilters extends OrcFiltersBase { // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters // in order to distinguish predicate pushdown for nested columns. expression match { - case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end()) + case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().equals(name, getType(name), castedValue).end()) - case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end()) + case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end()) - case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end()) + case LessThan(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThan(name, getType(name), castedValue).end()) - case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThanEquals(name, getType(name), 
castedValue).end()) - case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end()) + case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThan(name, getType(name), castedValue).end()) - case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - Some(builder.startAnd().isNull(quotedName, getType(attribute)).end()) + case IsNull(name) if isSearchableType(dataTypeMap(name)) => + Some(builder.startAnd().isNull(name, getType(name)).end()) - case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - Some(builder.startNot().isNull(quotedName, getType(attribute)).end()) + case IsNotNull(name) if isSearchableType(dataTypeMap(name)) => + Some(builder.startNot().isNull(name, getType(name)).end()) - case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) - Some(builder.startAnd().in(quotedName, getType(attribute), + case In(name, values) if isSearchableType(dataTypeMap(name)) => + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name))) + Some(builder.startAnd().in(name, getType(name), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 6e9e592be13be..675e089153679 100644 --- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -65,9 +65,11 @@ private[sql] object OrcFilters extends OrcFiltersBase { * Create ORC filter as a SearchArgument instance. */ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap // Combines all convertible filters using `And` to produce a single conjunction - val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) + // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. + val newFilters = filters.filter(!_.containsNestedColumn) + val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, newFilters)) conjunctionOptional.map { conjunction => // Then tries to build a single ORC `SearchArgument` for the conjunction predicate. 
// The input predicate is fully convertible. There should not be any empty result in the @@ -222,48 +224,39 @@ private[sql] object OrcFilters extends OrcFiltersBase { // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters // in order to distinguish predicate pushdown for nested columns. expression match { - case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end()) + case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().equals(name, getType(name), castedValue).end()) - case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end()) + case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end()) - case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end()) + case LessThan(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThan(name, getType(name), castedValue).end()) - case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end()) + case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThan(name, getType(name), castedValue).end()) - case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - Some(builder.startAnd().isNull(quotedName, getType(attribute)).end()) + case 
IsNull(name) if isSearchableType(dataTypeMap(name)) => + Some(builder.startAnd().isNull(name, getType(name)).end()) - case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - Some(builder.startNot().isNull(quotedName, getType(attribute)).end()) + case IsNotNull(name) if isSearchableType(dataTypeMap(name)) => + Some(builder.startNot().isNull(name, getType(name)).end()) - case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) - Some(builder.startAnd().in(quotedName, getType(attribute), + case In(name, values) if isSearchableType(dataTypeMap(name)) => + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name))) + Some(builder.startAnd().in(name, getType(name), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index cd1bffb6b7ab7..f9c514567c639 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.spark.SparkException import org.apache.spark.internal.Logging +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.execution.datasources.orc.{OrcFilters => DatasourceOrcFilters} import org.apache.spark.sql.execution.datasources.orc.OrcFilters.buildTree import org.apache.spark.sql.hive.HiveUtils @@ -73,9 +74,11 @@ private[orc] object OrcFilters extends Logging { if (HiveUtils.isHive23) { DatasourceOrcFilters.createFilter(schema, filters).asInstanceOf[Option[SearchArgument]] } else { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap + // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. + val newFilters = filters.filter(!_.containsNestedColumn) // Combines all convertible filters using `And` to produce a single conjunction - val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) + val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, newFilters)) conjunctionOptional.map { conjunction => // Then tries to build a single ORC `SearchArgument` for the conjunction predicate. // The input predicate is fully convertible. There should not be any empty result in the