From 385738d91e9e5145ff1f8c707194640acef1cd6e Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 5 Mar 2020 15:35:39 -0800 Subject: [PATCH 1/9] SPARK-17636 --- .../catalog/CatalogV2Implicits.scala | 8 + .../apache/spark/sql/sources/filters.scala | 120 +++++++++- .../datasources/DataSourceStrategy.scala | 8 +- .../datasources/orc/OrcFiltersBase.scala | 9 + .../datasources/parquet/ParquetFilters.scala | 79 ++++--- .../datasources/v2/orc/OrcScanBuilder.scala | 7 +- .../datasources/DataSourceStrategySuite.scala | 54 ++++- .../parquet/ParquetFilterSuite.scala | 165 +++++++++----- .../datasources/parquet/ParquetIOSuite.scala | 20 +- .../datasources/parquet/ParquetTest.scala | 29 ++- .../spark/sql/sources/FiltersSuite.scala | 205 +++++++++++++----- .../datasources/orc/OrcFilters.scala | 65 +++--- .../datasources/orc/OrcFilters.scala | 66 +++--- .../spark/sql/hive/orc/OrcFilters.scala | 7 +- 14 files changed, 599 insertions(+), 243 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 71bab624f06a8..d90804f4b6ff6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.connector.catalog import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform} +import org.apache.spark.sql.internal.SQLConf /** * Conversion helpers for working with v2 [[CatalogPlugin]]. @@ -132,4 +134,10 @@ private[sql] object CatalogV2Implicits { part } } + + private lazy val catalystSqlParser = new CatalystSqlParser(SQLConf.get) + + def parseColumnPath(name: String): Seq[String] = { + catalystSqlParser.parseMultipartIdentifier(name) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 020dd79f8f0d7..570219a9cdaff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.annotation.{Evolving, Stable} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. @@ -32,6 +33,10 @@ import org.apache.spark.annotation.{Evolving, Stable} sealed abstract class Filter { /** * List of columns that are referenced by this filter. + * + * Note that, each element in `references` represents a column; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. + * * @since 2.1.0 */ def references: Array[String] @@ -40,17 +45,42 @@ sealed abstract class Filter { case f: Filter => f.references case _ => Array.empty } + + /** + * List of columns that are referenced by this filter. 
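+   * For example, a reference written as "`a.b`.c" is returned as the two-part name
+   * Array("a.b", "c"): the backquoted first part keeps its dot.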
+ * + * @return each element is a column name as an array of string multi-identifier + * @since 3.0.0 + */ + def V2references: Array[Array[String]] = { + this.references.map(parseColumnPath(_).toArray) + } + + /** + * If any of the references of this filter contains nested column + */ + private[sql] def containsNestedColumn: Boolean = { + this.V2references.exists(_.length > 1) + } } /** - * A filter that evaluates to `true` iff the attribute evaluates to a value + * A filter that evaluates to `true` iff the column evaluates to a value * equal to `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable case class EqualTo(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + + /** + * A column name as an array of string multi-identifier + */ + val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -58,60 +88,103 @@ case class EqualTo(attribute: String, value: Any) extends Filter { * in that it returns `true` (rather than NULL) if both inputs are NULL, and `false` * (rather than NULL) if one of the input is NULL and the other is not NULL. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.5.0 */ @Stable case class EqualNullSafe(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + + /** + * A column name as an array of string multi-identifier + */ + val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** * A filter that evaluates to `true` iff the attribute evaluates to a value * greater than `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable case class GreaterThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + + /** + * A column name as an array of string multi-identifier + */ + val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** * A filter that evaluates to `true` iff the attribute evaluates to a value * greater than or equal to `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + + /** + * A column name as an array of string multi-identifier + */ + val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** * A filter that evaluates to `true` iff the attribute evaluates to a value * less than `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. 
* @since 1.3.0 */ @Stable case class LessThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + + /** + * A column name as an array of string multi-identifier + */ + val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** * A filter that evaluates to `true` iff the attribute evaluates to a value * less than or equal to `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable case class LessThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + + /** + * A column name as an array of string multi-identifier + */ + val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** * A filter that evaluates to `true` iff the attribute evaluates to one of the values in the array. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable @@ -134,26 +207,47 @@ case class In(attribute: String, values: Array[Any]) extends Filter { } override def references: Array[String] = Array(attribute) ++ values.flatMap(findReferences) + + /** + * A column name as an array of string multi-identifier + */ + val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** * A filter that evaluates to `true` iff the attribute evaluates to null. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable case class IsNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) + + /** + * A column name as an array of string multi-identifier + */ + val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** * A filter that evaluates to `true` iff the attribute evaluates to a non-null value. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.0 */ @Stable case class IsNotNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) + + /** + * A column name as an array of string multi-identifier + */ + val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -190,33 +284,57 @@ case class Not(child: Filter) extends Filter { * A filter that evaluates to `true` iff the attribute evaluates to * a string that starts with `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.1 */ @Stable case class StringStartsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + + /** + * A column name as an array of string multi-identifier + */ + val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** * A filter that evaluates to `true` iff the attribute evaluates to * a string that ends with `value`. 
* + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.1 */ @Stable case class StringEndsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + + /** + * A column name as an array of string multi-identifier + */ + val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** * A filter that evaluates to `true` iff the attribute evaluates to * a string that contains the string `value`. * + * @param attribute of the column to be evaluated; `dots` are used as separators + * for nested columns. If any part of the names contains `dots`, + * it is quoted to avoid confusion. * @since 1.3.1 */ @Stable case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + + /** + * A column name as an array of string multi-identifier + */ + val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 1641b660a271d..08700122e3f3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -652,10 +652,12 @@ object DataSourceStrategy { */ object PushableColumn { def unapply(e: Expression): Option[String] = { - def helper(e: Expression) = e match { - case a: Attribute => Some(a.name) + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + def helper(e: Expression): Option[Seq[String]] = e match { + case a: Attribute => Some(Seq(a.name)) + case s: GetStructField => helper(s.child).map(_ :+ s.childSchema(s.ordinal).name) case _ => None } - helper(e) + helper(e).map(_.quoted) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala index e673309188756..aa23b50aa4f7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.execution.datasources.orc +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.{And, Filter} import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType} +import org.apache.spark.sql.types.StructType /** * Methods that can be shared when upgrading the built-in Hive. @@ -45,4 +47,11 @@ trait OrcFiltersBase { case _: AtomicType => true case _ => false } + + /** + * The key of the dataTypeMap will be quoted if it contains `dots`. 
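+   * For example, a top-level column named "a.b" is stored under the key "`a.b`";
+   * nested fields are not flattened into this map.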
+ */ + protected[orc] def quotedDataTypeMap(schema: StructType): Map[String, DataType] = { + schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 07065018343cf..75af25301ad25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -27,7 +27,7 @@ import scala.collection.JavaConverters.asScalaBufferConverter import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.SparkFilterApi._ import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator} +import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveComparator, PrimitiveType, Type} import org.apache.parquet.schema.OriginalType._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ @@ -49,15 +49,34 @@ class ParquetFilters( pushDownInFilterThreshold: Int, caseSensitive: Boolean) { // A map which contains parquet field name and data type, if predicate push down applies. - private val nameToParquetField : Map[String, ParquetField] = { - // Here we don't flatten the fields in the nested schema but just look up through - // root fields. Currently, accessing to nested fields does not push down filters - // and it does not support to create filters for them. - val primitiveFields = - schema.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => - f.getName -> ParquetField(f.getName, - ParquetSchemaType(f.getOriginalType, - f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)) + // The keys are the column names. For nested column, `dot` will be used as a separator. + // For column name that contains `dot`, backquote will be used. + // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. + private val nameToParquetField : Map[String, ParquetPrimitiveField] = { + // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. + // `parentFieldNames` is used to keep track of the current nested level when traversing. + def getPrimitiveFields( + fields: Seq[Type], + parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { + fields.flatMap { + case p: PrimitiveType => + Some(ParquetPrimitiveField(fieldNames = parentFieldNames :+ p.getName, + fieldType = ParquetSchemaType(p.getOriginalType, + p.getPrimitiveTypeName, p.getTypeLength, p.getDecimalMetadata))) + // Note that when g is a `Struct`, `g.getOriginalType` is `null`. + // When g is a `Map`, `g.getOriginalType` is `MAP`. + // When g is a `List`, `g.getOriginalType` is `LIST`. + case g: GroupType if g.getOriginalType == null => + getPrimitiveFields(g.getFields.asScala, parentFieldNames :+ g.getName) + // Parquet only supports push-down for primitive types; as a result, Map and List types + // are removed. 
+ case _ => None + } + } + + val primitiveFields = getPrimitiveFields(schema.getFields.asScala).map { field => + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + (field.fieldNames.toSeq.quoted, field) } if (caseSensitive) { primitiveFields.toMap @@ -75,13 +94,13 @@ class ParquetFilters( } /** - * Holds a single field information stored in the underlying parquet file. + * Holds a single primitive field information stored in the underlying parquet file. * - * @param fieldName field name in parquet file + * @param fieldNames a field name as an array of string multi-identifier in parquet file * @param fieldType field type related info in parquet file */ - private case class ParquetField( - fieldName: String, + private case class ParquetPrimitiveField( + fieldNames: Array[String], fieldType: ParquetSchemaType) private case class ParquetSchemaType( @@ -163,7 +182,6 @@ class ParquetFilters( (n: Array[String], v: Any) => FilterApi.eq( longColumn(n), Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) - case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => (n: Array[String], v: Any) => FilterApi.eq( intColumn(n), @@ -472,13 +490,8 @@ class ParquetFilters( case _ => false } - // Parquet does not allow dots in the column name because dots are used as a column path - // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates - // with missing columns. The incorrect results could be got from Parquet when we push down - // filters for the column having dots in the names. Thus, we do not push down such filters. - // See SPARK-20364. private def canMakeFilterOn(name: String, value: Any): Boolean = { - nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value) + nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) } /** @@ -509,38 +522,38 @@ class ParquetFilters( predicate match { case sources.IsNull(name) if canMakeFilterOn(name, null) => makeEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), null)) + .map(_(nameToParquetField(name).fieldNames, null)) case sources.IsNotNull(name) if canMakeFilterOn(name, null) => makeNotEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), null)) + .map(_(nameToParquetField(name).fieldNames, null)) case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => makeEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => makeNotEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => makeEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => makeNotEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.LessThan(name, value) if canMakeFilterOn(name, value) => makeLt.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + 
.map(_(nameToParquetField(name).fieldNames, value)) case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => makeLtEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => makeGt.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => makeGtEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), value)) + .map(_(nameToParquetField(name).fieldNames, value)) case sources.And(lhs, rhs) => // At here, it is not safe to just convert one side and remove the other side @@ -591,13 +604,13 @@ class ParquetFilters( && values.distinct.length <= pushDownInFilterThreshold => values.distinct.flatMap { v => makeEq.lift(nameToParquetField(name).fieldType) - .map(_(Array(nameToParquetField(name).fieldName), v)) + .map(_(nameToParquetField(name).fieldNames, v)) }.reduceLeftOption(FilterApi.or) case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name, prefix) => Option(prefix).map { v => - FilterApi.userDefined(binaryColumn(Array(nameToParquetField(name).fieldName)), + FilterApi.userDefined(binaryColumn(nameToParquetField(name).fieldNames), new UserDefinedPredicate[Binary] with Serializable { private val strToBinary = Binary.fromReusedByteArray(v.getBytes) private val size = strToBinary.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 1421ffd8b6de4..9f40f5faa2e99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -22,6 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcFilters @@ -59,8 +60,10 @@ case class OrcScanBuilder( // changed `hadoopConf` in executors. OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames) } - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap - _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray + val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap + // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. 
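+      // For example (illustrative), given Array(EqualTo("a.b", 1), EqualTo("c", 2)), only
+      // EqualTo("c", 2) survives: EqualTo("a.b", 1).V2references is Array(Array("a", "b")),
+      // so containsNestedColumn is true and that predicate is left for Spark to evaluate.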
+ val newFilters = filters.filter(!_.containsNestedColumn) + _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, newFilters).toArray } filters } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index 7bd3213b378ce..a775a97895cfc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -26,15 +26,61 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT class DataSourceStrategySuite extends PlanTest with SharedSparkSession { val attrInts = Seq( - 'cint.int + 'cint.int, + Symbol("c.int").int, + GetStructField('a.struct(StructType( + StructField("cstr", StringType, nullable = true) :: + StructField("cint", IntegerType, nullable = true) :: Nil)), 1, None), + GetStructField('a.struct(StructType( + StructField("c.int", IntegerType, nullable = true) :: + StructField("cstr", StringType, nullable = true) :: Nil)), 0, None), + GetStructField(Symbol("a.b").struct(StructType( + StructField("cstr1", StringType, nullable = true) :: + StructField("cstr2", StringType, nullable = true) :: + StructField("cint", IntegerType, nullable = true) :: Nil)), 2, None), + GetStructField(Symbol("a.b").struct(StructType( + StructField("c.int", IntegerType, nullable = true) :: Nil)), 0, None), + GetStructField(GetStructField('a.struct(StructType( + StructField("cstr1", StringType, nullable = true) :: + StructField("b", StructType(StructField("cint", IntegerType, nullable = true) :: + StructField("cstr2", StringType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None) ).zip(Seq( - "cint" + "cint", + "`c.int`", // single level field that contains `dot` in name + "a.cint", // two level nested field + "a.`c.int`", // two level nested field, and nested level contains `dot` + "`a.b`.cint", // two level nested field, and top level contains `dot` + "`a.b`.`c.int`", // two level nested field, and both levels contain `dot` + "a.b.cint" // three level nested field )) val attrStrs = Seq( - 'cstr.string + 'cstr.string, + Symbol("c.str").string, + GetStructField('a.struct(StructType( + StructField("cint", IntegerType, nullable = true) :: + StructField("cstr", StringType, nullable = true) :: Nil)), 1, None), + GetStructField('a.struct(StructType( + StructField("c.str", StringType, nullable = true) :: + StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None), + GetStructField(Symbol("a.b").struct(StructType( + StructField("cint1", IntegerType, nullable = true) :: + StructField("cint2", IntegerType, nullable = true) :: + StructField("cstr", StringType, nullable = true) :: Nil)), 2, None), + GetStructField(Symbol("a.b").struct(StructType( + StructField("c.str", StringType, nullable = true) :: Nil)), 0, None), + GetStructField(GetStructField('a.struct(StructType( + StructField("cint1", IntegerType, nullable = true) :: + StructField("b", StructType(StructField("cstr", StringType, nullable = true) :: + StructField("cint2", IntegerType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None) ).zip(Seq( - "cstr" + "cstr", + "`c.str`", // single level field that contains `dot` in name + "a.cstr", // two level nested field + "a.`c.str`", // two level nested field, and nested level contains `dot` + "`a.b`.cstr", // two level nested field, and top level contains `dot` + 
"`a.b`.`c.str`", // two level nested field, and both levels contain `dot` + "a.b.cstr" // three level nested field )) test("translate simple expression") { attrInts.zip(attrStrs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 4e0c1c2dbe601..31d6784253a8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -62,6 +62,12 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} */ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSparkSession { + case class ColA[T](a: Option[T]) + + case class ColB[T](b: Option[T]) + + case class ColC[T](c: Option[T]) + protected def createParquetFilters( schema: MessageType, caseSensitive: Option[Boolean] = None): ParquetFilters = @@ -128,37 +134,47 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared val ts3 = data(2) val ts4 = data(3) - withParquetDataFrame(data.map(i => Tuple1(i))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i))) - - checkFilterPredicate('_1 === ts1, classOf[Eq[_]], ts1) - checkFilterPredicate('_1 <=> ts1, classOf[Eq[_]], ts1) - checkFilterPredicate('_1 =!= ts1, classOf[NotEq[_]], - Seq(ts2, ts3, ts4).map(i => Row.apply(i))) - - checkFilterPredicate('_1 < ts2, classOf[Lt[_]], ts1) - checkFilterPredicate('_1 > ts1, classOf[Gt[_]], Seq(ts2, ts3, ts4).map(i => Row.apply(i))) - checkFilterPredicate('_1 <= ts1, classOf[LtEq[_]], ts1) - checkFilterPredicate('_1 >= ts4, classOf[GtEq[_]], ts4) - - checkFilterPredicate(Literal(ts1) === '_1, classOf[Eq[_]], ts1) - checkFilterPredicate(Literal(ts1) <=> '_1, classOf[Eq[_]], ts1) - checkFilterPredicate(Literal(ts2) > '_1, classOf[Lt[_]], ts1) - checkFilterPredicate(Literal(ts3) < '_1, classOf[Gt[_]], ts4) - checkFilterPredicate(Literal(ts1) >= '_1, classOf[LtEq[_]], ts1) - checkFilterPredicate(Literal(ts4) <= '_1, classOf[GtEq[_]], ts4) - - checkFilterPredicate(!('_1 < ts4), classOf[GtEq[_]], ts4) - checkFilterPredicate('_1 < ts2 || '_1 > ts3, classOf[Operators.Or], Seq(Row(ts1), Row(ts4))) - } - } - - private def testDecimalPushDown(data: DataFrame)(f: DataFrame => Unit): Unit = { - withTempPath { file => - data.write.parquet(file.getCanonicalPath) - readParquetFile(file.toString)(f) - } + Seq( + ( + spark.createDataFrame(data.map(x => ColA(Some(x)))), + "a", // zero nesting + (x: Any) => x), + ( + spark.createDataFrame(data.map(x => ColA(Some(ColB(Some(x)))))), + "a.b", // one level nesting + (x: Any) => Row(x)), + ( + spark.createDataFrame(data.map(x => ColA(Some(ColB(Some(ColC(Some(x)))))))), + "a.b.c", // two level nesting + (x: Any) => Row(Row(x))) + ).foreach { case (i, pushDownColName, resultFun) => withParquetDFfromDF(i) { implicit df => + val tsAttr = df(pushDownColName).expr + checkFilterPredicate(tsAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]], + data.map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(tsAttr === ts1, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr <=> ts1, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr =!= ts1, classOf[NotEq[_]], + Seq(ts2, ts3, 
ts4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(tsAttr < ts2, classOf[Lt[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr > ts1, classOf[Gt[_]], + Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i)))) + checkFilterPredicate(tsAttr <= ts1, classOf[LtEq[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr >= ts4, classOf[GtEq[_]], resultFun(ts4)) + + checkFilterPredicate(Literal(ts1) === tsAttr, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts1) <=> tsAttr, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts2) > tsAttr, classOf[Lt[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts3) < tsAttr, classOf[Gt[_]], resultFun(ts4)) + checkFilterPredicate(Literal(ts1) >= tsAttr, classOf[LtEq[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts4) <= tsAttr, classOf[GtEq[_]], resultFun(ts4)) + + checkFilterPredicate(!(tsAttr < ts4), classOf[GtEq[_]], resultFun(ts4)) + checkFilterPredicate(tsAttr < ts2 || tsAttr > ts3, classOf[Operators.Or], + Seq(Row(resultFun(ts1)), Row(resultFun(ts4)))) + }} } // This function tests that exactly go through the `canDrop` and `inverseCanDrop`. @@ -187,18 +203,35 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - boolean") { - withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) - - checkFilterPredicate('_1 === true, classOf[Eq[_]], true) - checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true) - checkFilterPredicate('_1 =!= true, classOf[NotEq[_]], false) - } - } + val data = true :: false :: Nil + Seq( + ( + spark.sqlContext.createDataFrame(data.map(x => ColA(Option(x)))), + "a", // zero nesting + (x: Any) => x), + ( + spark.sqlContext.createDataFrame(data.map(x => ColA(Option(ColB(Option(x)))))), + "a.b", // one level nesting + (x: Any) => Row(x)), + ( + spark.sqlContext.createDataFrame( + data.map(x => ColA(Option(ColB(Option(ColC(Option(x)))))))), + "a.b.c", // two level nesting + (x: Any) => Row(Row(x))) + ).foreach { case (i, pushDownColName, resultFun) => withParquetDFfromDF(i) { implicit df => + val booleanAttr = df(pushDownColName).expr + checkFilterPredicate(booleanAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(booleanAttr.isNotNull, classOf[NotEq[_]], + Seq(Row(resultFun(true)), Row(resultFun(false)))) + + checkFilterPredicate(booleanAttr === true, classOf[Eq[_]], resultFun(true)) + checkFilterPredicate(booleanAttr <=> true, classOf[Eq[_]], resultFun(true)) + checkFilterPredicate(booleanAttr =!= true, classOf[NotEq[_]], resultFun(false)) + } + }} test("filter pushdown - tinyint") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toByte)))) { implicit df => + withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toByte)))) { implicit df => assert(df.schema.head.dataType === ByteType) checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -226,7 +259,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - smallint") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => + withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => assert(df.schema.head.dataType === ShortType) checkFilterPredicate('_1.isNull, 
classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -254,7 +287,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - integer") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -280,7 +313,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - long") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => + withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -306,7 +339,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - float") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => + withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -332,7 +365,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - double") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => + withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -358,7 +391,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - string") { - withParquetDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => + withParquetDFfromObjs((1 to 4).map(i => Tuple1(i.toString))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate( '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) @@ -390,7 +423,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) } - withParquetDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => + withParquetDFfromObjs((1 to 4).map(i => Tuple1(i.b))) { implicit df => checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) checkBinaryFilterPredicate('_1 <=> 1.b, classOf[Eq[_]], 1.b) @@ -426,7 +459,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21") - withParquetDataFrame(data.map(i => Tuple1(i.date))) { implicit df => + withParquetDFfromObjs(data.map(i => Tuple1(i.date))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i.date))) @@ -485,7 +518,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> 
ParquetOutputTimestampType.INT96.toString) { - withParquetDataFrame(millisData.map(i => Tuple1(i))) { implicit df => + withParquetDFfromObjs(millisData.map(i => Tuple1(i))) { implicit df => val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema) assertResult(None) { createParquetFilters(schema).createFilter(sources.IsNull("_1")) @@ -506,7 +539,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared val rdd = spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i)))) val dataFrame = spark.createDataFrame(rdd, schema) - testDecimalPushDown(dataFrame) { implicit df => + withParquetDFfromDF(dataFrame) { implicit df => assert(df.schema === schema) checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -1042,7 +1075,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("SPARK-16371 Do not push down filters when inner name and outer name are the same") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Tuple1(i)))) { implicit df => + withParquetDFfromObjs((1 to 4).map(i => Tuple1(Tuple1(i)))) { implicit df => // Here the schema becomes as below: // // root @@ -1107,7 +1140,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } } - test("SPARK-20364: Disable Parquet predicate pushdown for fields having dots in the names") { + test("SPARK-31026: Parquet predicate pushdown for fields having dots in the names") { import testImplicits._ Seq(true, false).foreach { vectorized => @@ -1120,6 +1153,28 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared assert(readBack.count() == 1) } } + + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString, + // Makes sure disabling 'spark.sql.parquet.recordFilter' still enables + // row group level filtering. + SQLConf.PARQUET_RECORD_FILTER_ENABLED.key -> "false", + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + + withTempPath { path => + val data = (1 to 1024) + data.toDF("col.dots").coalesce(1) + .write.option("parquet.block.size", 512) + .parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath).filter("`col.dots` == 500") + // Here, we strip the Spark side filter and check the actual results from Parquet. + val actual = stripSparkFilter(df).collect().length + // Since those are filtered at row group level, the result count should be less + // than the total length but should not be a single record. + // Note that, if record level filtering is enabled, it should be a single record. + // If no filter is pushed down to Parquet, it should be the total length of data. + assert(actual > 1 && actual < data.length) + } + } } } @@ -1162,7 +1217,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - StringStartsWith") { - withParquetDataFrame((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df => + withParquetDFfromObjs((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df => checkFilterPredicate( '_1.startsWith("").asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], @@ -1208,7 +1263,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } // SPARK-28371: make sure filter is null-safe. 
- withParquetDataFrame(Seq(Tuple1[String](null))) { implicit df => + withParquetDFfromObjs(Seq(Tuple1[String](null))) { implicit df => checkFilterPredicate( '_1.startsWith("blah").asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 7f85fd2a1629a..3e614e75f6033 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -82,7 +82,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession * Writes `data` to a Parquet file, reads it back and check file contents. */ protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = { - withParquetDataFrame(data)(r => checkAnswer(r, data.map(Row.fromTuple))) + withParquetDFfromObjs(data)(r => checkAnswer(r, data.map(Row.fromTuple))) } test("basic data types (without binary)") { @@ -94,7 +94,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession test("raw binary") { val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) - withParquetDataFrame(data) { df => + withParquetDFfromObjs(data) { df => assertResult(data.map(_._1.mkString(",")).sorted) { df.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted } @@ -197,7 +197,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession testStandardAndLegacyModes("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) - withParquetDataFrame(data) { df => + withParquetDFfromObjs(data) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(struct) => Row(Row(struct.productIterator.toSeq: _*)) @@ -214,7 +214,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession ) ) } - withParquetDataFrame(data) { df => + withParquetDFfromObjs(data) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(array) => Row(array.map(struct => Row(struct.productIterator.toSeq: _*))) @@ -233,7 +233,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession ) ) } - withParquetDataFrame(data) { df => + withParquetDFfromObjs(data) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(array) => Row(array.map { case Tuple1(Tuple1(str)) => Row(Row(str))}) @@ -243,7 +243,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession testStandardAndLegacyModes("nested struct with array of array as field") { val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) - withParquetDataFrame(data) { df => + withParquetDFfromObjs(data) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(struct) => Row(Row(struct.productIterator.toSeq: _*)) @@ -260,7 +260,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession ) ) } - withParquetDataFrame(data) { df => + withParquetDFfromObjs(data) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(m) => Row(m.map { case (k, v) => Row(k.productIterator.toSeq: _*) -> v }) @@ -277,7 +277,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession ) ) } - withParquetDataFrame(data) { df => + withParquetDFfromObjs(data) { df => // Structs are converted 
to `Row`s checkAnswer(df, data.map { case Tuple1(m) => Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) @@ -293,7 +293,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession null.asInstanceOf[java.lang.Float], null.asInstanceOf[java.lang.Double]) - withParquetDataFrame(allNulls :: Nil) { df => + withParquetDFfromObjs(allNulls :: Nil) { df => val rows = df.collect() assert(rows.length === 1) assert(rows.head === Row(Seq.fill(5)(null): _*)) @@ -306,7 +306,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession None.asInstanceOf[Option[Long]], None.asInstanceOf[Option[String]]) - withParquetDataFrame(allNones :: Nil) { df => + withParquetDFfromObjs(allNones :: Nil) { df => val rows = df.collect() assert(rows.length === 1) assert(rows.head === Row(Seq.fill(3)(null): _*)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 828ba6aee026b..8d175017fba8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -63,13 +63,38 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { (f: String => Unit): Unit = withDataSourceFile(data)(f) /** - * Writes `data` to a Parquet file and reads it back as a [[DataFrame]], + * Writes `data` objects to a Parquet file and reads it back as a [[DataFrame]], * which is then passed to `f`. The Parquet file will be deleted after `f` returns. */ - protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] + protected def withParquetDFfromObjs[T <: Product: ClassTag: TypeTag] (data: Seq[T], testVectorized: Boolean = true) (f: DataFrame => Unit): Unit = withDataSourceDataFrame(data, testVectorized)(f) + /** + * Writes `df` dataframe to a Parquet file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Parquet file will be deleted after `f` returns. + */ + protected def withParquetDFfromDF[T <: Product: ClassTag: TypeTag] + (df: DataFrame, testVectorized: Boolean = true) + (f: DataFrame => Unit): Unit = { + withTempPath { file => + df.write.format(dataSourceName).save(file.getCanonicalPath) + readFile(file.getCanonicalPath, testVectorized)(f) + } + } + + /** + * Writes `df` to a Parquet file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Parquet file will be deleted after `f` returns. + */ + protected def toParquetDataFrame(df: DataFrame, testVectorized: Boolean = true) + (f: DataFrame => Unit): Unit = { + withTempPath { file => + df.write.format(dataSourceName).save(file.getCanonicalPath) + readFile(file.getCanonicalPath, testVectorized)(f) + } + } + /** * Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a * temporary table named `tableName`, then call `f`. 
The temporary table together with the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala index 1cb7a2156c3d3..bf28acaa65bb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala @@ -24,66 +24,155 @@ import org.apache.spark.SparkFunSuite */ class FiltersSuite extends SparkFunSuite { - test("EqualTo references") { - assert(EqualTo("a", "1").references.toSeq == Seq("a")) - assert(EqualTo("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) + private def withFieldNames(f: (String, Array[String]) => Unit): Unit = { + Seq(("a", Array("a")), + ("a.b", Array("a", "b")), + ("`a.b`.c", Array("a.b", "c")), + ("`a.b`.`c.d`.`e.f`", Array("a.b", "c.d", "e.f")) + ).foreach { case (name, fieldNames) => + f(name, fieldNames) + } } - test("EqualNullSafe references") { - assert(EqualNullSafe("a", "1").references.toSeq == Seq("a")) - assert(EqualNullSafe("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) - } - - test("GreaterThan references") { - assert(GreaterThan("a", "1").references.toSeq == Seq("a")) - assert(GreaterThan("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) - } + test("EqualTo references") { withFieldNames { (name, fieldNames) => + assert(EqualTo(name, "1").references.toSeq == Seq(name)) + assert(EqualTo(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(EqualTo(name, "1").fieldNames.toSeq == fieldNames.toSeq) + + assert(EqualTo(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) + assert(EqualTo("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) + + assert(EqualTo(name, EqualTo("b", "2")).V2references.toSeq.map(_.toSeq) + == Seq(fieldNames.toSeq, Seq("b"))) + assert(EqualTo("b", EqualTo(name, "2")).V2references.toSeq.map(_.toSeq) + == Seq(Seq("b"), fieldNames.toSeq)) + }} + + test("EqualNullSafe references") { withFieldNames { (name, fieldNames) => + assert(EqualNullSafe(name, "1").references.toSeq == Seq(name)) + assert(EqualNullSafe(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(EqualNullSafe(name, "1").fieldNames.toSeq == fieldNames.toSeq) + + assert(EqualNullSafe(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) + assert(EqualNullSafe("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) + + assert(EqualNullSafe(name, EqualTo("b", "2")).V2references.toSeq.map(_.toSeq) + == Seq(fieldNames.toSeq, Seq("b"))) + assert(EqualNullSafe("b", EqualTo(name, "2")).V2references.toSeq.map(_.toSeq) + == Seq(Seq("b"), fieldNames.toSeq)) + }} + + test("GreaterThan references") { withFieldNames { (name, fieldNames) => + assert(GreaterThan(name, "1").references.toSeq == Seq(name)) + assert(GreaterThan(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(GreaterThan(name, "1").fieldNames.toSeq == fieldNames.toSeq) + + assert(GreaterThan(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) + assert(GreaterThan("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) + + assert(GreaterThan(name, EqualTo("b", "2")).V2references.toSeq.map(_.toSeq) + == Seq(fieldNames.toSeq, Seq("b"))) + assert(GreaterThan("b", EqualTo(name, "2")).V2references.toSeq.map(_.toSeq) + == Seq(Seq("b"), fieldNames.toSeq)) + }} + + test("GreaterThanOrEqual references") { withFieldNames { (name, fieldNames) => + assert(GreaterThanOrEqual(name, "1").references.toSeq == 
Seq(name)) + assert(GreaterThanOrEqual(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(GreaterThanOrEqual(name, "1").fieldNames.toSeq == fieldNames.toSeq) + + assert(GreaterThanOrEqual(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) + assert(GreaterThanOrEqual("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) + + assert(GreaterThanOrEqual(name, EqualTo("b", "2")).V2references.toSeq.map(_.toSeq) + == Seq(fieldNames.toSeq, Seq("b"))) + assert(GreaterThanOrEqual("b", EqualTo(name, "2")).V2references.toSeq.map(_.toSeq) + == Seq(Seq("b"), fieldNames.toSeq)) + }} + + test("LessThan references") { withFieldNames { (name, fieldNames) => + assert(LessThan(name, "1").references.toSeq == Seq(name)) + assert(LessThan(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(LessThan(name, "1").fieldNames.toSeq == fieldNames.toSeq) - test("GreaterThanOrEqual references") { - assert(GreaterThanOrEqual("a", "1").references.toSeq == Seq("a")) - assert(GreaterThanOrEqual("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) - } - - test("LessThan references") { - assert(LessThan("a", "1").references.toSeq == Seq("a")) assert(LessThan("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) - } - - test("LessThanOrEqual references") { - assert(LessThanOrEqual("a", "1").references.toSeq == Seq("a")) - assert(LessThanOrEqual("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) - } - - test("In references") { - assert(In("a", Array("1")).references.toSeq == Seq("a")) - assert(In("a", Array("1", EqualTo("b", "2"))).references.toSeq == Seq("a", "b")) - } - - test("IsNull references") { - assert(IsNull("a").references.toSeq == Seq("a")) - } - - test("IsNotNull references") { - assert(IsNotNull("a").references.toSeq == Seq("a")) - } - - test("And references") { - assert(And(EqualTo("a", "1"), EqualTo("b", "1")).references.toSeq == Seq("a", "b")) - } - - test("Or references") { - assert(Or(EqualTo("a", "1"), EqualTo("b", "1")).references.toSeq == Seq("a", "b")) - } - - test("StringStartsWith references") { - assert(StringStartsWith("a", "str").references.toSeq == Seq("a")) - } - - test("StringEndsWith references") { - assert(StringEndsWith("a", "str").references.toSeq == Seq("a")) - } - - test("StringContains references") { - assert(StringContains("a", "str").references.toSeq == Seq("a")) - } + }} + + test("LessThanOrEqual references") { withFieldNames { (name, fieldNames) => + assert(LessThanOrEqual(name, "1").references.toSeq == Seq(name)) + assert(LessThanOrEqual(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(LessThanOrEqual(name, "1").fieldNames.toSeq == fieldNames.toSeq) + + assert(LessThanOrEqual(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) + assert(LessThanOrEqual("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) + + assert(LessThanOrEqual(name, EqualTo("b", "2")).V2references.toSeq.map(_.toSeq) + == Seq(fieldNames.toSeq, Seq("b"))) + assert(LessThanOrEqual("b", EqualTo(name, "2")).V2references.toSeq.map(_.toSeq) + == Seq(Seq("b"), fieldNames.toSeq)) + }} + + test("In references") { withFieldNames { (name, fieldNames) => + assert(In(name, Array("1")).references.toSeq == Seq(name)) + assert(In(name, Array("1")).V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(In(name, Array("1")).fieldNames.toSeq == fieldNames.toSeq) + + assert(In(name, Array("1", EqualTo("b", "2"))).references.toSeq == Seq(name, "b")) + assert(In("b", Array("1", EqualTo(name, 
"2"))).references.toSeq == Seq("b", name)) + + assert(In(name, Array("1", EqualTo("b", "2"))).V2references.toSeq.map(_.toSeq) + == Seq(fieldNames.toSeq, Seq("b"))) + assert(In("b", Array("1", EqualTo(name, "2"))).V2references.toSeq.map(_.toSeq) + == Seq(Seq("b"), fieldNames.toSeq)) + }} + + test("IsNull references") { withFieldNames { (name, fieldNames) => + assert(IsNull(name).references.toSeq == Seq(name)) + assert(IsNull(name).V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(IsNull(name).fieldNames.toSeq == fieldNames.toSeq) + }} + + test("IsNotNull references") { withFieldNames { (name, fieldNames) => + assert(IsNotNull(name).references.toSeq == Seq(name)) + assert(IsNull(name).V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(IsNull(name).fieldNames.toSeq == fieldNames.toSeq) + }} + + test("And references") { withFieldNames { (name, fieldNames) => + assert(And(EqualTo(name, "1"), EqualTo("b", "1")).references.toSeq == Seq(name, "b")) + assert(And(EqualTo("b", "1"), EqualTo(name, "1")).references.toSeq == Seq("b", name)) + + assert(And(EqualTo(name, "1"), EqualTo("b", "1")).V2references.toSeq.map(_.toSeq) == + Seq(fieldNames.toSeq, Seq("b"))) + assert(And(EqualTo("b", "1"), EqualTo(name, "1")).V2references.toSeq.map(_.toSeq) == + Seq(Seq("b"), fieldNames.toSeq)) + }} + + test("Or references") { withFieldNames { (name, fieldNames) => + assert(Or(EqualTo(name, "1"), EqualTo("b", "1")).references.toSeq == Seq(name, "b")) + assert(Or(EqualTo("b", "1"), EqualTo(name, "1")).references.toSeq == Seq("b", name)) + + assert(Or(EqualTo(name, "1"), EqualTo("b", "1")).V2references.toSeq.map(_.toSeq) == + Seq(fieldNames.toSeq, Seq("b"))) + assert(Or(EqualTo("b", "1"), EqualTo(name, "1")).V2references.toSeq.map(_.toSeq) == + Seq(Seq("b"), fieldNames.toSeq)) + }} + + test("StringStartsWith references") { withFieldNames { (name, fieldNames) => + assert(StringStartsWith(name, "str").references.toSeq == Seq(name)) + assert(StringStartsWith(name, "str").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(StringStartsWith(name, "str").fieldNames.toSeq == fieldNames.toSeq) + }} + + test("StringEndsWith references") { withFieldNames { (name, fieldNames) => + assert(StringEndsWith(name, "str").references.toSeq == Seq(name)) + assert(StringEndsWith(name, "str").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(StringEndsWith(name, "str").fieldNames.toSeq == fieldNames.toSeq) + }} + + test("StringContains references") { withFieldNames { (name, fieldNames) => + assert(StringContains(name, "str").references.toSeq == Seq(name)) + assert(StringContains(name, "str").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(StringContains(name, "str").fieldNames.toSeq == fieldNames.toSeq) + }} } diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index b9cbc484e1fc1..6e404cabfa7a9 100644 --- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -65,9 +65,11 @@ private[sql] object OrcFilters extends OrcFiltersBase { * Create ORC filter as a SearchArgument instance. 
*/ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + val dataTypeMap = quotedDataTypeMap(schema) // Combines all convertible filters using `And` to produce a single conjunction - val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) + // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. + val newFilters = filters.filter(!_.containsNestedColumn) + val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, newFilters)) conjunctionOptional.map { conjunction => // Then tries to build a single ORC `SearchArgument` for the conjunction predicate. // The input predicate is fully convertible. There should not be any empty result in the @@ -222,48 +224,39 @@ private[sql] object OrcFilters extends OrcFiltersBase { // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters // in order to distinguish predicate pushdown for nested columns. expression match { - case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end()) + case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().equals(name, getType(name), castedValue).end()) - case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end()) + case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end()) - case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end()) + case LessThan(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThan(name, getType(name), castedValue).end()) - case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThanEquals(name, 
getType(name), castedValue).end()) - case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end()) + case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThan(name, getType(name), castedValue).end()) - case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - Some(builder.startAnd().isNull(quotedName, getType(attribute)).end()) + case IsNull(name) if isSearchableType(dataTypeMap(name)) => + Some(builder.startAnd().isNull(name, getType(name)).end()) - case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - Some(builder.startNot().isNull(quotedName, getType(attribute)).end()) + case IsNotNull(name) if isSearchableType(dataTypeMap(name)) => + Some(builder.startNot().isNull(name, getType(name)).end()) - case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) - Some(builder.startAnd().in(quotedName, getType(attribute), + case In(name, values) if isSearchableType(dataTypeMap(name)) => + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name))) + Some(builder.startAnd().in(name, getType(name), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 6e9e592be13be..4f0bc415a174a 100644 --- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -24,7 +24,6 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable import org.apache.spark.SparkException -import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ @@ -65,9 +64,11 @@ private[sql] object OrcFilters extends OrcFiltersBase { * Create ORC filter as a SearchArgument instance. */ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + val dataTypeMap = quotedDataTypeMap(schema) // Combines all convertible filters using `And` to produce a single conjunction - val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) + // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. + val newFilters = filters.filter(!_.containsNestedColumn) + val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, newFilters)) conjunctionOptional.map { conjunction => // Then tries to build a single ORC `SearchArgument` for the conjunction predicate. // The input predicate is fully convertible. 
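The new `filters.filter(!_.containsNestedColumn)` step above hinges on how attribute names are split into parts. A minimal sketch of the intended behaviour at this point in the series (illustrative only; it assumes the multipart parser keeps a backquoted segment as a single part, and `V2references` is renamed to `v2references` by a later commit in this series):

    import org.apache.spark.sql.sources.EqualTo

    // Nested column reference: two name parts, so ORC pushdown drops the filter for now.
    EqualTo("a.b", 1).V2references.map(_.toSeq)    // Array(Seq("a", "b"))
    // Top-level column whose name merely contains a dot: one part, still pushed down.
    EqualTo("`a.b`", 1).V2references.map(_.toSeq)  // Array(Seq("a.b"))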
There should not be any empty result in the @@ -222,48 +223,39 @@ private[sql] object OrcFilters extends OrcFiltersBase { // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters // in order to distinguish predicate pushdown for nested columns. expression match { - case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end()) + case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().equals(name, getType(name), castedValue).end()) - case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end()) + case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end()) - case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end()) + case LessThan(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThan(name, getType(name), castedValue).end()) - case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end()) + case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThan(name, getType(name), castedValue).end()) - case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - Some(builder.startAnd().isNull(quotedName, getType(attribute)).end()) + case IsNull(name) if 
isSearchableType(dataTypeMap(name)) => + Some(builder.startAnd().isNull(name, getType(name)).end()) - case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - Some(builder.startNot().isNull(quotedName, getType(attribute)).end()) + case IsNotNull(name) if isSearchableType(dataTypeMap(name)) => + Some(builder.startNot().isNull(name, getType(name)).end()) - case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) - Some(builder.startAnd().in(quotedName, getType(attribute), + case In(name, values) if isSearchableType(dataTypeMap(name)) => + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name))) + Some(builder.startAnd().in(name, getType(name), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index cd1bffb6b7ab7..f9c514567c639 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.spark.SparkException import org.apache.spark.internal.Logging +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.execution.datasources.orc.{OrcFilters => DatasourceOrcFilters} import org.apache.spark.sql.execution.datasources.orc.OrcFilters.buildTree import org.apache.spark.sql.hive.HiveUtils @@ -73,9 +74,11 @@ private[orc] object OrcFilters extends Logging { if (HiveUtils.isHive23) { DatasourceOrcFilters.createFilter(schema, filters).asInstanceOf[Option[SearchArgument]] } else { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap + // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. + val newFilters = filters.filter(!_.containsNestedColumn) // Combines all convertible filters using `And` to produce a single conjunction - val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) + val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, newFilters)) conjunctionOptional.map { conjunction => // Then tries to build a single ORC `SearchArgument` for the conjunction predicate. // The input predicate is fully convertible. 
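For context on the builder calls in these hunks: when every filter in the conjunction is convertible, the result is one nested SearchArgument. A rough, hand-written illustration of the shape for And(EqualTo("a", 1L), IsNotNull("a")) on a LONG column, using the same three-argument leaf methods as the code above; this is a sketch, not the patch's exact code, and the printed form is only approximate:

    import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgumentFactory}

    val sarg = SearchArgumentFactory.newBuilder()
      .startAnd()
        .equals("a", PredicateLeaf.Type.LONG, Long.box(1L))      // EqualTo("a", 1L)
        .startNot().isNull("a", PredicateLeaf.Type.LONG).end()   // IsNotNull("a")
      .end()
      .build()
    // roughly: leaf-0 = (EQUALS a 1), leaf-1 = (IS_NULL a), expr = (and leaf-0 (not leaf-1))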
There should not be any empty result in the From 22a7a0029943a696dc0f085972d8568473c24ec5 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 23 Mar 2020 17:19:33 -0700 Subject: [PATCH 2/9] address feedback --- .../apache/spark/sql/internal/SQLConf.scala | 13 ++++ .../apache/spark/sql/sources/filters.scala | 60 ------------------- .../datasources/DataSourceStrategy.scala | 11 +++- .../datasources/parquet/ParquetFilters.scala | 5 +- .../parquet/ParquetFilterSuite.scala | 32 +++++----- .../datasources/parquet/ParquetIOSuite.scala | 20 +++---- .../datasources/parquet/ParquetTest.scala | 25 ++------ .../spark/sql/sources/FiltersSuite.scala | 12 ---- 8 files changed, 55 insertions(+), 123 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9a524defb2816..a0e991ff1fd1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2049,6 +2049,17 @@ object SQLConf { .booleanConf .createWithDefault(true) + val NESTED_PREDICATE_PUSHDOWN_ENABLED = + buildConf("spark.sql.optimizer.nestedPredicatePushdown.enabled") + .internal() + .doc("When true, Spark tries to push down predicates for nested columns and or names " + + "containing `dots` to data sources. Currently, Parquet implements both optimizations " + + "while ORC only supports predicates for names containing `dots`. The other data sources" + + "don't support this feature yet.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + val SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED = buildConf("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled") .internal() @@ -3035,6 +3046,8 @@ class SQLConf extends Serializable with Logging { def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED) + def nestedPredicatePushdownEnabled: Boolean = getConf(NESTED_PREDICATE_PUSHDOWN_ENABLED) + def serializerNestedSchemaPruningEnabled: Boolean = getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 570219a9cdaff..8687be58dae80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -76,11 +76,6 @@ sealed abstract class Filter { @Stable case class EqualTo(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) - - /** - * A column name as an array of string multi-identifier - */ - val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -96,11 +91,6 @@ case class EqualTo(attribute: String, value: Any) extends Filter { @Stable case class EqualNullSafe(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) - - /** - * A column name as an array of string multi-identifier - */ - val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -115,11 +105,6 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter { @Stable case class GreaterThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) - - /** - * A column name as an array of string multi-identifier - */ - val 
fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -134,11 +119,6 @@ case class GreaterThan(attribute: String, value: Any) extends Filter { @Stable case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) - - /** - * A column name as an array of string multi-identifier - */ - val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -153,11 +133,6 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { @Stable case class LessThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) - - /** - * A column name as an array of string multi-identifier - */ - val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -172,11 +147,6 @@ case class LessThan(attribute: String, value: Any) extends Filter { @Stable case class LessThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) - - /** - * A column name as an array of string multi-identifier - */ - val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -207,11 +177,6 @@ case class In(attribute: String, values: Array[Any]) extends Filter { } override def references: Array[String] = Array(attribute) ++ values.flatMap(findReferences) - - /** - * A column name as an array of string multi-identifier - */ - val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -225,11 +190,6 @@ case class In(attribute: String, values: Array[Any]) extends Filter { @Stable case class IsNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) - - /** - * A column name as an array of string multi-identifier - */ - val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -243,11 +203,6 @@ case class IsNull(attribute: String) extends Filter { @Stable case class IsNotNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) - - /** - * A column name as an array of string multi-identifier - */ - val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -292,11 +247,6 @@ case class Not(child: Filter) extends Filter { @Stable case class StringStartsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) - - /** - * A column name as an array of string multi-identifier - */ - val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -311,11 +261,6 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { @Stable case class StringEndsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) - - /** - * A column name as an array of string multi-identifier - */ - val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** @@ -330,11 +275,6 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { @Stable case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) - - /** - * A column name as an array of string multi-identifier - */ - val fieldNames: Array[String] = parseColumnPath(attribute).toArray } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 08700122e3f3e..faf37609ad814 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -652,10 +652,17 @@ object DataSourceStrategy { */ object PushableColumn { def unapply(e: Expression): Option[String] = { + val nestedPredicatePushdownEnabled = SQLConf.get.nestedPredicatePushdownEnabled import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper def helper(e: Expression): Option[Seq[String]] = e match { - case a: Attribute => Some(Seq(a.name)) - case s: GetStructField => helper(s.child).map(_ :+ s.childSchema(s.ordinal).name) + case a: Attribute => + if (nestedPredicatePushdownEnabled || !a.name.contains(".")) { + Some(Seq(a.name)) + } else { + None + } + case s: GetStructField if nestedPredicatePushdownEnabled => + helper(s.child).map(_ :+ s.childSchema(s.ordinal).name) case _ => None } helper(e).map(_.quoted) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 75af25301ad25..65731fbb2590d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -49,8 +49,9 @@ class ParquetFilters( pushDownInFilterThreshold: Int, caseSensitive: Boolean) { // A map which contains parquet field name and data type, if predicate push down applies. - // The keys are the column names. For nested column, `dot` will be used as a separator. - // For column name that contains `dot`, backquote will be used. + // + // Each key in `nameToParquetField` represents a column; `dots` are used as separators for + // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. private val nameToParquetField : Map[String, ParquetPrimitiveField] = { // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. 
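To make the PushableColumn change and the new flag concrete, here is a hedged end-to-end sketch. It assumes a running SparkSession named `spark`; the path, the column names and the exact PushedFilters rendering are made up for illustration and not verified output:

    import org.apache.spark.sql.functions.col

    spark.conf.set("spark.sql.optimizer.nestedPredicatePushdown.enabled", "true")
    spark.range(10).selectExpr("named_struct('b', id) AS a")
      .write.mode("overwrite").parquet("/tmp/nested_t")
    spark.read.parquet("/tmp/nested_t").filter(col("a.b") === 1).explain()
    // Expected to show something like: PushedFilters: [IsNotNull(a.b), EqualTo(a.b,1)]

    // With the flag set to false, PushableColumn only accepts dot-free top-level attributes,
    // so the same predicate is evaluated after the Parquet scan instead of being pushed down.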
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 31d6784253a8b..97de50a521aef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -147,7 +147,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared spark.createDataFrame(data.map(x => ColA(Some(ColB(Some(ColC(Some(x)))))))), "a.b.c", // two level nesting (x: Any) => Row(Row(x))) - ).foreach { case (i, pushDownColName, resultFun) => withParquetDFfromDF(i) { implicit df => + ).foreach { case (i, pushDownColName, resultFun) => withParquetDataFrame(i) { implicit df => val tsAttr = df(pushDownColName).expr checkFilterPredicate(tsAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]], @@ -218,7 +218,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared data.map(x => ColA(Option(ColB(Option(ColC(Option(x)))))))), "a.b.c", // two level nesting (x: Any) => Row(Row(x))) - ).foreach { case (i, pushDownColName, resultFun) => withParquetDFfromDF(i) { implicit df => + ).foreach { case (i, pushDownColName, resultFun) => withParquetDataFrame(i) { implicit df => val booleanAttr = df(pushDownColName).expr checkFilterPredicate(booleanAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate(booleanAttr.isNotNull, classOf[NotEq[_]], @@ -231,7 +231,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared }} test("filter pushdown - tinyint") { - withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toByte)))) { implicit df => + withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toByte))))) { implicit df => assert(df.schema.head.dataType === ByteType) checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -259,7 +259,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - smallint") { - withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => + withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toShort))))) { implicit df => assert(df.schema.head.dataType === ShortType) checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -287,7 +287,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - integer") { - withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i))))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -313,7 +313,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - long") { - withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => + withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toLong))))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 
4).map(Row.apply(_))) @@ -339,7 +339,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - float") { - withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => + withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toFloat))))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -365,7 +365,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - double") { - withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => + withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toDouble))))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -391,7 +391,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - string") { - withParquetDFfromObjs((1 to 4).map(i => Tuple1(i.toString))) { implicit df => + withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(i.toString)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate( '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) @@ -423,7 +423,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) } - withParquetDFfromObjs((1 to 4).map(i => Tuple1(i.b))) { implicit df => + withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(i.b)))) { implicit df => checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) checkBinaryFilterPredicate('_1 <=> 1.b, classOf[Eq[_]], 1.b) @@ -459,7 +459,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21") - withParquetDFfromObjs(data.map(i => Tuple1(i.date))) { implicit df => + withParquetDataFrame(toDF(data.map(i => Tuple1(i.date)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i.date))) @@ -518,7 +518,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> ParquetOutputTimestampType.INT96.toString) { - withParquetDFfromObjs(millisData.map(i => Tuple1(i))) { implicit df => + withParquetDataFrame(toDF(millisData.map(i => Tuple1(i)))) { implicit df => val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema) assertResult(None) { createParquetFilters(schema).createFilter(sources.IsNull("_1")) @@ -539,7 +539,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared val rdd = spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i)))) val dataFrame = spark.createDataFrame(rdd, schema) - withParquetDFfromDF(dataFrame) { implicit df => + withParquetDataFrame(dataFrame) { implicit df => assert(df.schema === schema) checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -1075,7 +1075,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } 
test("SPARK-16371 Do not push down filters when inner name and outer name are the same") { - withParquetDFfromObjs((1 to 4).map(i => Tuple1(Tuple1(i)))) { implicit df => + withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Tuple1(i))))) { implicit df => // Here the schema becomes as below: // // root @@ -1217,7 +1217,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - StringStartsWith") { - withParquetDFfromObjs((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df => + withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(i + "str" + i)))) { implicit df => checkFilterPredicate( '_1.startsWith("").asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], @@ -1263,7 +1263,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } // SPARK-28371: make sure filter is null-safe. - withParquetDFfromObjs(Seq(Tuple1[String](null))) { implicit df => + withParquetDataFrame(toDF(Seq(Tuple1[String](null)))) { implicit df => checkFilterPredicate( '_1.startsWith("blah").asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 3e614e75f6033..497b823868450 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -82,7 +82,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession * Writes `data` to a Parquet file, reads it back and check file contents. */ protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = { - withParquetDFfromObjs(data)(r => checkAnswer(r, data.map(Row.fromTuple))) + withParquetDataFrame(data.toDF())(r => checkAnswer(r, data.map(Row.fromTuple))) } test("basic data types (without binary)") { @@ -94,7 +94,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession test("raw binary") { val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) - withParquetDFfromObjs(data) { df => + withParquetDataFrame(data.toDF()) { df => assertResult(data.map(_._1.mkString(",")).sorted) { df.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted } @@ -197,7 +197,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession testStandardAndLegacyModes("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) - withParquetDFfromObjs(data) { df => + withParquetDataFrame(data.toDF()) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(struct) => Row(Row(struct.productIterator.toSeq: _*)) @@ -214,7 +214,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession ) ) } - withParquetDFfromObjs(data) { df => + withParquetDataFrame(data.toDF()) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(array) => Row(array.map(struct => Row(struct.productIterator.toSeq: _*))) @@ -233,7 +233,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession ) ) } - withParquetDFfromObjs(data) { df => + withParquetDataFrame(data.toDF()) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(array) => Row(array.map { case Tuple1(Tuple1(str)) => Row(Row(str))}) @@ -243,7 +243,7 @@ class ParquetIOSuite extends 
QueryTest with ParquetTest with SharedSparkSession testStandardAndLegacyModes("nested struct with array of array as field") { val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) - withParquetDFfromObjs(data) { df => + withParquetDataFrame(data.toDF()) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(struct) => Row(Row(struct.productIterator.toSeq: _*)) @@ -260,7 +260,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession ) ) } - withParquetDFfromObjs(data) { df => + withParquetDataFrame(data.toDF()) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(m) => Row(m.map { case (k, v) => Row(k.productIterator.toSeq: _*) -> v }) @@ -277,7 +277,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession ) ) } - withParquetDFfromObjs(data) { df => + withParquetDataFrame(data.toDF()) { df => // Structs are converted to `Row`s checkAnswer(df, data.map { case Tuple1(m) => Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) @@ -293,7 +293,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession null.asInstanceOf[java.lang.Float], null.asInstanceOf[java.lang.Double]) - withParquetDFfromObjs(allNulls :: Nil) { df => + withParquetDataFrame((allNulls :: Nil).toDF()) { df => val rows = df.collect() assert(rows.length === 1) assert(rows.head === Row(Seq.fill(5)(null): _*)) @@ -306,7 +306,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession None.asInstanceOf[Option[Long]], None.asInstanceOf[Option[String]]) - withParquetDFfromObjs(allNones :: Nil) { df => + withParquetDataFrame((allNones :: Nil).toDF()) { df => val rows = df.collect() assert(rows.length === 1) assert(rows.head === Row(Seq.fill(3)(null): _*)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 8d175017fba8a..6e5c562571997 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -62,32 +62,15 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { (data: Seq[T]) (f: String => Unit): Unit = withDataSourceFile(data)(f) - /** - * Writes `data` objects to a Parquet file and reads it back as a [[DataFrame]], - * which is then passed to `f`. The Parquet file will be deleted after `f` returns. - */ - protected def withParquetDFfromObjs[T <: Product: ClassTag: TypeTag] - (data: Seq[T], testVectorized: Boolean = true) - (f: DataFrame => Unit): Unit = withDataSourceDataFrame(data, testVectorized)(f) - - /** - * Writes `df` dataframe to a Parquet file and reads it back as a [[DataFrame]], - * which is then passed to `f`. The Parquet file will be deleted after `f` returns. 
- */ - protected def withParquetDFfromDF[T <: Product: ClassTag: TypeTag] - (df: DataFrame, testVectorized: Boolean = true) - (f: DataFrame => Unit): Unit = { - withTempPath { file => - df.write.format(dataSourceName).save(file.getCanonicalPath) - readFile(file.getCanonicalPath, testVectorized)(f) - } + protected def toDF[T <: Product: ClassTag: TypeTag](data: Seq[T]): DataFrame = { + spark.createDataFrame(data) } /** - * Writes `df` to a Parquet file and reads it back as a [[DataFrame]], + * Writes `df` dataframe to a Parquet file and reads it back as a [[DataFrame]], * which is then passed to `f`. The Parquet file will be deleted after `f` returns. */ - protected def toParquetDataFrame(df: DataFrame, testVectorized: Boolean = true) + protected def withParquetDataFrame(df: DataFrame, testVectorized: Boolean = true) (f: DataFrame => Unit): Unit = { withTempPath { file => df.write.format(dataSourceName).save(file.getCanonicalPath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala index bf28acaa65bb0..6629e3dfa0c99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala @@ -37,7 +37,6 @@ class FiltersSuite extends SparkFunSuite { test("EqualTo references") { withFieldNames { (name, fieldNames) => assert(EqualTo(name, "1").references.toSeq == Seq(name)) assert(EqualTo(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - assert(EqualTo(name, "1").fieldNames.toSeq == fieldNames.toSeq) assert(EqualTo(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) assert(EqualTo("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) @@ -51,7 +50,6 @@ class FiltersSuite extends SparkFunSuite { test("EqualNullSafe references") { withFieldNames { (name, fieldNames) => assert(EqualNullSafe(name, "1").references.toSeq == Seq(name)) assert(EqualNullSafe(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - assert(EqualNullSafe(name, "1").fieldNames.toSeq == fieldNames.toSeq) assert(EqualNullSafe(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) assert(EqualNullSafe("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) @@ -65,7 +63,6 @@ class FiltersSuite extends SparkFunSuite { test("GreaterThan references") { withFieldNames { (name, fieldNames) => assert(GreaterThan(name, "1").references.toSeq == Seq(name)) assert(GreaterThan(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - assert(GreaterThan(name, "1").fieldNames.toSeq == fieldNames.toSeq) assert(GreaterThan(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) assert(GreaterThan("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) @@ -79,7 +76,6 @@ class FiltersSuite extends SparkFunSuite { test("GreaterThanOrEqual references") { withFieldNames { (name, fieldNames) => assert(GreaterThanOrEqual(name, "1").references.toSeq == Seq(name)) assert(GreaterThanOrEqual(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - assert(GreaterThanOrEqual(name, "1").fieldNames.toSeq == fieldNames.toSeq) assert(GreaterThanOrEqual(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) assert(GreaterThanOrEqual("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) @@ -93,7 +89,6 @@ class FiltersSuite extends SparkFunSuite { test("LessThan references") { withFieldNames { (name, fieldNames) => assert(LessThan(name, "1").references.toSeq 
== Seq(name)) assert(LessThan(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - assert(LessThan(name, "1").fieldNames.toSeq == fieldNames.toSeq) assert(LessThan("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) }} @@ -101,7 +96,6 @@ class FiltersSuite extends SparkFunSuite { test("LessThanOrEqual references") { withFieldNames { (name, fieldNames) => assert(LessThanOrEqual(name, "1").references.toSeq == Seq(name)) assert(LessThanOrEqual(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - assert(LessThanOrEqual(name, "1").fieldNames.toSeq == fieldNames.toSeq) assert(LessThanOrEqual(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) assert(LessThanOrEqual("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) @@ -115,7 +109,6 @@ class FiltersSuite extends SparkFunSuite { test("In references") { withFieldNames { (name, fieldNames) => assert(In(name, Array("1")).references.toSeq == Seq(name)) assert(In(name, Array("1")).V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - assert(In(name, Array("1")).fieldNames.toSeq == fieldNames.toSeq) assert(In(name, Array("1", EqualTo("b", "2"))).references.toSeq == Seq(name, "b")) assert(In("b", Array("1", EqualTo(name, "2"))).references.toSeq == Seq("b", name)) @@ -129,13 +122,11 @@ class FiltersSuite extends SparkFunSuite { test("IsNull references") { withFieldNames { (name, fieldNames) => assert(IsNull(name).references.toSeq == Seq(name)) assert(IsNull(name).V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - assert(IsNull(name).fieldNames.toSeq == fieldNames.toSeq) }} test("IsNotNull references") { withFieldNames { (name, fieldNames) => assert(IsNotNull(name).references.toSeq == Seq(name)) assert(IsNull(name).V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - assert(IsNull(name).fieldNames.toSeq == fieldNames.toSeq) }} test("And references") { withFieldNames { (name, fieldNames) => @@ -161,18 +152,15 @@ class FiltersSuite extends SparkFunSuite { test("StringStartsWith references") { withFieldNames { (name, fieldNames) => assert(StringStartsWith(name, "str").references.toSeq == Seq(name)) assert(StringStartsWith(name, "str").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - assert(StringStartsWith(name, "str").fieldNames.toSeq == fieldNames.toSeq) }} test("StringEndsWith references") { withFieldNames { (name, fieldNames) => assert(StringEndsWith(name, "str").references.toSeq == Seq(name)) assert(StringEndsWith(name, "str").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - assert(StringEndsWith(name, "str").fieldNames.toSeq == fieldNames.toSeq) }} test("StringContains references") { withFieldNames { (name, fieldNames) => assert(StringContains(name, "str").references.toSeq == Seq(name)) assert(StringContains(name, "str").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) - assert(StringContains(name, "str").fieldNames.toSeq == fieldNames.toSeq) }} } From b775e02e016ad70271ff4d6090d907a65aadc318 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 24 Mar 2020 21:14:16 -0700 Subject: [PATCH 3/9] add more test cases --- .../datasources/orc/OrcFiltersBase.scala | 9 - .../parquet/ParquetFilterSuite.scala | 201 ++++++++++-------- .../datasources/orc/OrcFilters.scala | 2 +- .../datasources/orc/OrcFilters.scala | 4 +- 4 files changed, 111 insertions(+), 105 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala index aa23b50aa4f7a..e673309188756 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.execution.datasources.orc -import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.{And, Filter} import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType} -import org.apache.spark.sql.types.StructType /** * Methods that can be shared when upgrading the built-in Hive. @@ -47,11 +45,4 @@ trait OrcFiltersBase { case _: AtomicType => true case _ => false } - - /** - * The key of the dataTypeMap will be quoted if it contains `dots`. - */ - protected[orc] def quotedDataTypeMap(schema: StructType): Map[String, DataType] = { - schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 97de50a521aef..f040ed9d985f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -62,12 +62,6 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} */ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSparkSession { - case class ColA[T](a: Option[T]) - - case class ColB[T](b: Option[T]) - - case class ColC[T](c: Option[T]) - protected def createParquetFilters( schema: MessageType, caseSensitive: Option[Boolean] = None): ParquetFilters = @@ -127,54 +121,80 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } - private def testTimestampPushdown(data: Seq[Timestamp]): Unit = { - assert(data.size === 4) - val ts1 = data.head - val ts2 = data(1) - val ts3 = data(2) - val ts4 = data(3) - + /** + * Takes single level `inputDF` dataframe to generate multi-level nested + * dataframes as new test data. 
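An informal summary of what the `withNestedDataFrame` helper defined just below is expected to produce from a single-column input, condensed from the cases in its body (a reading aid, not additional test code):

    //  column layout                          pushDownColName   result wrapper
    //  a                                      "a"               x => x
    //  struct a { b }                         "a.b"             x => Row(x)
    //  struct a { struct b { c } }            "a.b.c"           x => Row(Row(x))
    //  top-level column literally named a.b   "`a.b`"           x => x
    //  struct `a.b` { `c.d` }                 "`a.b`.`c.d`"     x => Row(x)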
+ */ + private def withNestedDataFrame(inputDF: DataFrame) + (testCases: (DataFrame, String, Any => Any) => Unit): Unit = { + val df = inputDF.toDF("temp") Seq( ( - spark.createDataFrame(data.map(x => ColA(Some(x)))), + df.withColumnRenamed("temp", "a"), "a", // zero nesting (x: Any) => x), ( - spark.createDataFrame(data.map(x => ColA(Some(ColB(Some(x)))))), + df.withColumn("a", struct(df("temp") as "b")).drop("temp"), "a.b", // one level nesting (x: Any) => Row(x)), ( - spark.createDataFrame(data.map(x => ColA(Some(ColB(Some(ColC(Some(x)))))))), + df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"), "a.b.c", // two level nesting - (x: Any) => Row(Row(x))) - ).foreach { case (i, pushDownColName, resultFun) => withParquetDataFrame(i) { implicit df => - val tsAttr = df(pushDownColName).expr - checkFilterPredicate(tsAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]], - data.map(i => Row.apply(resultFun(i)))) - - checkFilterPredicate(tsAttr === ts1, classOf[Eq[_]], resultFun(ts1)) - checkFilterPredicate(tsAttr <=> ts1, classOf[Eq[_]], resultFun(ts1)) - checkFilterPredicate(tsAttr =!= ts1, classOf[NotEq[_]], - Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i)))) - - checkFilterPredicate(tsAttr < ts2, classOf[Lt[_]], resultFun(ts1)) - checkFilterPredicate(tsAttr > ts1, classOf[Gt[_]], - Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(tsAttr <= ts1, classOf[LtEq[_]], resultFun(ts1)) - checkFilterPredicate(tsAttr >= ts4, classOf[GtEq[_]], resultFun(ts4)) - - checkFilterPredicate(Literal(ts1) === tsAttr, classOf[Eq[_]], resultFun(ts1)) - checkFilterPredicate(Literal(ts1) <=> tsAttr, classOf[Eq[_]], resultFun(ts1)) - checkFilterPredicate(Literal(ts2) > tsAttr, classOf[Lt[_]], resultFun(ts1)) - checkFilterPredicate(Literal(ts3) < tsAttr, classOf[Gt[_]], resultFun(ts4)) - checkFilterPredicate(Literal(ts1) >= tsAttr, classOf[LtEq[_]], resultFun(ts1)) - checkFilterPredicate(Literal(ts4) <= tsAttr, classOf[GtEq[_]], resultFun(ts4)) - - checkFilterPredicate(!(tsAttr < ts4), classOf[GtEq[_]], resultFun(ts4)) - checkFilterPredicate(tsAttr < ts2 || tsAttr > ts3, classOf[Operators.Or], - Seq(Row(resultFun(ts1)), Row(resultFun(ts4)))) - }} + (x: Any) => Row(Row(x)) + ), + ( + df.withColumnRenamed("temp", "a.b"), + "`a.b`", // zero nesting with column name containing `dots` + (x: Any) => x + ), + ( + df.withColumn("a.b", struct(df("temp") as "c.d") ).drop("temp"), + "`a.b`.`c.d`", // one level nesting with column names containing `dots` + (x: Any) => Row(x) + ) + ).foreach { case (df, pushDownColName, resultTransFun) => + testCases(df, pushDownColName, resultTransFun) + } + } + + private def testTimestampPushdown(data: Seq[Timestamp]): Unit = { + assert(data.size === 4) + val ts1 = data.head + val ts2 = data(1) + val ts3 = data(2) + val ts4 = data(3) + + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => + withParquetDataFrame(df) { implicit df => + val tsAttr = df(pushDownColName).expr + checkFilterPredicate(tsAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]], + data.map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(tsAttr === ts1, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr <=> ts1, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr =!= ts1, classOf[NotEq[_]], + Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(tsAttr < ts2, 
classOf[Lt[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr > ts1, classOf[Gt[_]], + Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i)))) + checkFilterPredicate(tsAttr <= ts1, classOf[LtEq[_]], resultFun(ts1)) + checkFilterPredicate(tsAttr >= ts4, classOf[GtEq[_]], resultFun(ts4)) + + checkFilterPredicate(Literal(ts1) === tsAttr, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts1) <=> tsAttr, classOf[Eq[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts2) > tsAttr, classOf[Lt[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts3) < tsAttr, classOf[Gt[_]], resultFun(ts4)) + checkFilterPredicate(Literal(ts1) >= tsAttr, classOf[LtEq[_]], resultFun(ts1)) + checkFilterPredicate(Literal(ts4) <= tsAttr, classOf[GtEq[_]], resultFun(ts4)) + + checkFilterPredicate(!(tsAttr < ts4), classOf[GtEq[_]], resultFun(ts4)) + checkFilterPredicate(tsAttr < ts2 || tsAttr > ts3, classOf[Operators.Or], + Seq(Row(resultFun(ts1)), Row(resultFun(ts4)))) + } + } } // This function tests that exactly go through the `canDrop` and `inverseCanDrop`. @@ -204,57 +224,52 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("filter pushdown - boolean") { val data = true :: false :: Nil - Seq( - ( - spark.sqlContext.createDataFrame(data.map(x => ColA(Option(x)))), - "a", // zero nesting - (x: Any) => x), - ( - spark.sqlContext.createDataFrame(data.map(x => ColA(Option(ColB(Option(x)))))), - "a.b", // one level nesting - (x: Any) => Row(x)), - ( - spark.sqlContext.createDataFrame( - data.map(x => ColA(Option(ColB(Option(ColC(Option(x)))))))), - "a.b.c", // two level nesting - (x: Any) => Row(Row(x))) - ).foreach { case (i, pushDownColName, resultFun) => withParquetDataFrame(i) { implicit df => - val booleanAttr = df(pushDownColName).expr - checkFilterPredicate(booleanAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate(booleanAttr.isNotNull, classOf[NotEq[_]], - Seq(Row(resultFun(true)), Row(resultFun(false)))) - - checkFilterPredicate(booleanAttr === true, classOf[Eq[_]], resultFun(true)) - checkFilterPredicate(booleanAttr <=> true, classOf[Eq[_]], resultFun(true)) - checkFilterPredicate(booleanAttr =!= true, classOf[NotEq[_]], resultFun(false)) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => + withParquetDataFrame(df) { implicit df => + val booleanAttr = df(pushDownColName).expr + checkFilterPredicate(booleanAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(booleanAttr.isNotNull, classOf[NotEq[_]], + Seq(Row(resultFun(true)), Row(resultFun(false)))) + + checkFilterPredicate(booleanAttr === true, classOf[Eq[_]], resultFun(true)) + checkFilterPredicate(booleanAttr <=> true, classOf[Eq[_]], resultFun(true)) + checkFilterPredicate(booleanAttr =!= true, classOf[NotEq[_]], resultFun(false)) + } } - }} + } test("filter pushdown - tinyint") { - withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toByte))))) { implicit df => - assert(df.schema.head.dataType === ByteType) - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1.toByte, classOf[Eq[_]], 1) - checkFilterPredicate('_1 <=> 1.toByte, classOf[Eq[_]], 1) - checkFilterPredicate('_1 =!= 1.toByte, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 < 2.toByte, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3.toByte, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 
1.toByte, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4.toByte, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1.toByte) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1.toByte) <=> '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2.toByte) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3.toByte) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1.toByte) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4.toByte) <= '_1, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('_1 < 4.toByte), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 < 2.toByte || '_1 > 3.toByte, - classOf[Operators.Or], Seq(Row(1), Row(4))) + val df = toDF((1 to 4).map(i => Tuple1(Option(i.toByte)))) + withNestedDataFrame(df) { case (df, pushDownColName, resultFun) => + withParquetDataFrame(df) { implicit df => + val tinyIntAttr = df(pushDownColName).expr + assert(df(pushDownColName).expr.dataType === ByteType) + checkFilterPredicate(tinyIntAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(tinyIntAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(tinyIntAttr === 1.toByte, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(tinyIntAttr <=> 1.toByte, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(tinyIntAttr =!= 1.toByte, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(tinyIntAttr < 2.toByte, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(tinyIntAttr > 3.toByte, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(tinyIntAttr <= 1.toByte, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(tinyIntAttr >= 4.toByte, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1.toByte) === tinyIntAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1.toByte) <=> tinyIntAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2.toByte) > tinyIntAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3.toByte) < tinyIntAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1.toByte) >= tinyIntAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4.toByte) <= tinyIntAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(tinyIntAttr < 4.toByte), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(tinyIntAttr < 2.toByte || tinyIntAttr > 3.toByte, + classOf[Operators.Or], Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 6e404cabfa7a9..f5abd30854e00 100644 --- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -65,7 +65,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { * Create ORC filter as a SearchArgument instance. */ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = quotedDataTypeMap(schema) + val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap // Combines all convertible filters using `And` to produce a single conjunction // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. 
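A brief illustration of the inlined key construction above, under the assumption that `quoteIfNeeded` backquotes only names that actually need quoting (hypothetical schema, shown for orientation only):

    // StructType(Seq(StructField("a", IntegerType), StructField("a.b", StringType)))
    // yields Map("a" -> IntegerType, "`a.b`" -> StringType),
    // so lookups use the same quoted names that the pushed-down filters now carry.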
val newFilters = filters.filter(!_.containsNestedColumn) diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 4f0bc415a174a..6213b5a58e8d0 100644 --- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -22,8 +22,8 @@ import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable - import org.apache.spark.SparkException +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ @@ -64,7 +64,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { * Create ORC filter as a SearchArgument instance. */ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = quotedDataTypeMap(schema) + val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap // Combines all convertible filters using `And` to produce a single conjunction // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. val newFilters = filters.filter(!_.containsNestedColumn) From c40f5d035b6a4f82f0238285203cdbf74217b100 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 24 Mar 2020 21:15:41 -0700 Subject: [PATCH 4/9] address feedback --- .../apache/spark/sql/sources/filters.scala | 4 +- .../spark/sql/sources/FiltersSuite.scala | 56 +++++++++---------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 8687be58dae80..319073e4475be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -52,7 +52,7 @@ sealed abstract class Filter { * @return each element is a column name as an array of string multi-identifier * @since 3.0.0 */ - def V2references: Array[Array[String]] = { + def v2references: Array[Array[String]] = { this.references.map(parseColumnPath(_).toArray) } @@ -60,7 +60,7 @@ sealed abstract class Filter { * If any of the references of this filter contains nested column */ private[sql] def containsNestedColumn: Boolean = { - this.V2references.exists(_.length > 1) + this.v2references.exists(_.length > 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala index 6629e3dfa0c99..33b2db57d9f0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala @@ -36,106 +36,106 @@ class FiltersSuite extends SparkFunSuite { test("EqualTo references") { withFieldNames { (name, fieldNames) => assert(EqualTo(name, "1").references.toSeq == Seq(name)) - assert(EqualTo(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(EqualTo(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) assert(EqualTo(name, EqualTo("b", "2")).references.toSeq 
== Seq(name, "b")) assert(EqualTo("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) - assert(EqualTo(name, EqualTo("b", "2")).V2references.toSeq.map(_.toSeq) + assert(EqualTo(name, EqualTo("b", "2")).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq, Seq("b"))) - assert(EqualTo("b", EqualTo(name, "2")).V2references.toSeq.map(_.toSeq) + assert(EqualTo("b", EqualTo(name, "2")).v2references.toSeq.map(_.toSeq) == Seq(Seq("b"), fieldNames.toSeq)) }} test("EqualNullSafe references") { withFieldNames { (name, fieldNames) => assert(EqualNullSafe(name, "1").references.toSeq == Seq(name)) - assert(EqualNullSafe(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(EqualNullSafe(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) assert(EqualNullSafe(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) assert(EqualNullSafe("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) - assert(EqualNullSafe(name, EqualTo("b", "2")).V2references.toSeq.map(_.toSeq) + assert(EqualNullSafe(name, EqualTo("b", "2")).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq, Seq("b"))) - assert(EqualNullSafe("b", EqualTo(name, "2")).V2references.toSeq.map(_.toSeq) + assert(EqualNullSafe("b", EqualTo(name, "2")).v2references.toSeq.map(_.toSeq) == Seq(Seq("b"), fieldNames.toSeq)) }} test("GreaterThan references") { withFieldNames { (name, fieldNames) => assert(GreaterThan(name, "1").references.toSeq == Seq(name)) - assert(GreaterThan(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(GreaterThan(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) assert(GreaterThan(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) assert(GreaterThan("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) - assert(GreaterThan(name, EqualTo("b", "2")).V2references.toSeq.map(_.toSeq) + assert(GreaterThan(name, EqualTo("b", "2")).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq, Seq("b"))) - assert(GreaterThan("b", EqualTo(name, "2")).V2references.toSeq.map(_.toSeq) + assert(GreaterThan("b", EqualTo(name, "2")).v2references.toSeq.map(_.toSeq) == Seq(Seq("b"), fieldNames.toSeq)) }} test("GreaterThanOrEqual references") { withFieldNames { (name, fieldNames) => assert(GreaterThanOrEqual(name, "1").references.toSeq == Seq(name)) - assert(GreaterThanOrEqual(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(GreaterThanOrEqual(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) assert(GreaterThanOrEqual(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) assert(GreaterThanOrEqual("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) - assert(GreaterThanOrEqual(name, EqualTo("b", "2")).V2references.toSeq.map(_.toSeq) + assert(GreaterThanOrEqual(name, EqualTo("b", "2")).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq, Seq("b"))) - assert(GreaterThanOrEqual("b", EqualTo(name, "2")).V2references.toSeq.map(_.toSeq) + assert(GreaterThanOrEqual("b", EqualTo(name, "2")).v2references.toSeq.map(_.toSeq) == Seq(Seq("b"), fieldNames.toSeq)) }} test("LessThan references") { withFieldNames { (name, fieldNames) => assert(LessThan(name, "1").references.toSeq == Seq(name)) - assert(LessThan(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(LessThan(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) assert(LessThan("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) }} test("LessThanOrEqual references") { 
withFieldNames { (name, fieldNames) => assert(LessThanOrEqual(name, "1").references.toSeq == Seq(name)) - assert(LessThanOrEqual(name, "1").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(LessThanOrEqual(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) assert(LessThanOrEqual(name, EqualTo("b", "2")).references.toSeq == Seq(name, "b")) assert(LessThanOrEqual("b", EqualTo(name, "2")).references.toSeq == Seq("b", name)) - assert(LessThanOrEqual(name, EqualTo("b", "2")).V2references.toSeq.map(_.toSeq) + assert(LessThanOrEqual(name, EqualTo("b", "2")).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq, Seq("b"))) - assert(LessThanOrEqual("b", EqualTo(name, "2")).V2references.toSeq.map(_.toSeq) + assert(LessThanOrEqual("b", EqualTo(name, "2")).v2references.toSeq.map(_.toSeq) == Seq(Seq("b"), fieldNames.toSeq)) }} test("In references") { withFieldNames { (name, fieldNames) => assert(In(name, Array("1")).references.toSeq == Seq(name)) - assert(In(name, Array("1")).V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(In(name, Array("1")).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) assert(In(name, Array("1", EqualTo("b", "2"))).references.toSeq == Seq(name, "b")) assert(In("b", Array("1", EqualTo(name, "2"))).references.toSeq == Seq("b", name)) - assert(In(name, Array("1", EqualTo("b", "2"))).V2references.toSeq.map(_.toSeq) + assert(In(name, Array("1", EqualTo("b", "2"))).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq, Seq("b"))) - assert(In("b", Array("1", EqualTo(name, "2"))).V2references.toSeq.map(_.toSeq) + assert(In("b", Array("1", EqualTo(name, "2"))).v2references.toSeq.map(_.toSeq) == Seq(Seq("b"), fieldNames.toSeq)) }} test("IsNull references") { withFieldNames { (name, fieldNames) => assert(IsNull(name).references.toSeq == Seq(name)) - assert(IsNull(name).V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(IsNull(name).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) }} test("IsNotNull references") { withFieldNames { (name, fieldNames) => assert(IsNotNull(name).references.toSeq == Seq(name)) - assert(IsNull(name).V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(IsNull(name).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) }} test("And references") { withFieldNames { (name, fieldNames) => assert(And(EqualTo(name, "1"), EqualTo("b", "1")).references.toSeq == Seq(name, "b")) assert(And(EqualTo("b", "1"), EqualTo(name, "1")).references.toSeq == Seq("b", name)) - assert(And(EqualTo(name, "1"), EqualTo("b", "1")).V2references.toSeq.map(_.toSeq) == + assert(And(EqualTo(name, "1"), EqualTo("b", "1")).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq, Seq("b"))) - assert(And(EqualTo("b", "1"), EqualTo(name, "1")).V2references.toSeq.map(_.toSeq) == + assert(And(EqualTo("b", "1"), EqualTo(name, "1")).v2references.toSeq.map(_.toSeq) == Seq(Seq("b"), fieldNames.toSeq)) }} @@ -143,24 +143,24 @@ class FiltersSuite extends SparkFunSuite { assert(Or(EqualTo(name, "1"), EqualTo("b", "1")).references.toSeq == Seq(name, "b")) assert(Or(EqualTo("b", "1"), EqualTo(name, "1")).references.toSeq == Seq("b", name)) - assert(Or(EqualTo(name, "1"), EqualTo("b", "1")).V2references.toSeq.map(_.toSeq) == + assert(Or(EqualTo(name, "1"), EqualTo("b", "1")).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq, Seq("b"))) - assert(Or(EqualTo("b", "1"), EqualTo(name, "1")).V2references.toSeq.map(_.toSeq) == + assert(Or(EqualTo("b", "1"), EqualTo(name, 
"1")).v2references.toSeq.map(_.toSeq) == Seq(Seq("b"), fieldNames.toSeq)) }} test("StringStartsWith references") { withFieldNames { (name, fieldNames) => assert(StringStartsWith(name, "str").references.toSeq == Seq(name)) - assert(StringStartsWith(name, "str").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(StringStartsWith(name, "str").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) }} test("StringEndsWith references") { withFieldNames { (name, fieldNames) => assert(StringEndsWith(name, "str").references.toSeq == Seq(name)) - assert(StringEndsWith(name, "str").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(StringEndsWith(name, "str").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) }} test("StringContains references") { withFieldNames { (name, fieldNames) => assert(StringContains(name, "str").references.toSeq == Seq(name)) - assert(StringContains(name, "str").V2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) + assert(StringContains(name, "str").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) }} } From e9944e5dbf96ef17a24b2109033c2884cde4d2ca Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 24 Mar 2020 21:20:21 -0700 Subject: [PATCH 5/9] fix import --- .../apache/spark/sql/execution/datasources/orc/OrcFilters.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 6213b5a58e8d0..675e089153679 100644 --- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable + import org.apache.spark.SparkException import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.Filter From 44b310e423250c4635943c8e66323f37f03ef75b Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 24 Mar 2020 21:48:56 -0700 Subject: [PATCH 6/9] add more tests --- .../parquet/ParquetFilterSuite.scala | 306 +++++++++++------- 1 file changed, 184 insertions(+), 122 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index f040ed9d985f8..0d003b9026c05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -168,6 +168,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => withParquetDataFrame(df) { implicit df => val tsAttr = df(pushDownColName).expr + assert(df(pushDownColName).expr.dataType === TimestampType) + checkFilterPredicate(tsAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(resultFun(i)))) @@ -228,6 +230,8 @@ abstract class ParquetFilterSuite extends QueryTest with 
ParquetTest with Shared withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => withParquetDataFrame(df) { implicit df => val booleanAttr = df(pushDownColName).expr + assert(df(pushDownColName).expr.dataType === BooleanType) + checkFilterPredicate(booleanAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate(booleanAttr.isNotNull, classOf[NotEq[_]], Seq(Row(resultFun(true)), Row(resultFun(false)))) @@ -240,11 +244,13 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - tinyint") { - val df = toDF((1 to 4).map(i => Tuple1(Option(i.toByte)))) - withNestedDataFrame(df) { case (df, pushDownColName, resultFun) => + val data = (1 to 4).map(i => Tuple1(Option(i.toByte))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => withParquetDataFrame(df) { implicit df => val tinyIntAttr = df(pushDownColName).expr assert(df(pushDownColName).expr.dataType === ByteType) + checkFilterPredicate(tinyIntAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate(tinyIntAttr.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(resultFun(i)))) @@ -274,162 +280,218 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } test("filter pushdown - smallint") { - withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toShort))))) { implicit df => - assert(df.schema.head.dataType === ShortType) - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + val data = (1 to 4).map(i => Tuple1(Option(i.toShort))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => + withParquetDataFrame(df) { implicit df => + val smallIntAttr = df(pushDownColName).expr + assert(df(pushDownColName).expr.dataType === ShortType) - checkFilterPredicate('_1 === 1.toShort, classOf[Eq[_]], 1) - checkFilterPredicate('_1 <=> 1.toShort, classOf[Eq[_]], 1) - checkFilterPredicate('_1 =!= 1.toShort, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate(smallIntAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(smallIntAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) - checkFilterPredicate('_1 < 2.toShort, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3.toShort, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1.toShort, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4.toShort, classOf[GtEq[_]], 4) + checkFilterPredicate(smallIntAttr === 1.toShort, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(smallIntAttr <=> 1.toShort, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(smallIntAttr =!= 1.toShort, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(Literal(1.toShort) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1.toShort) <=> '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2.toShort) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3.toShort) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1.toShort) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4.toShort) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(smallIntAttr < 2.toShort, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(smallIntAttr > 3.toShort, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(smallIntAttr <= 1.toShort, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(smallIntAttr >= 
4.toShort, classOf[GtEq[_]], resultFun(4)) - checkFilterPredicate(!('_1 < 4.toShort), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 < 2.toShort || '_1 > 3.toShort, - classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(Literal(1.toShort) === smallIntAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1.toShort) <=> smallIntAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2.toShort) > smallIntAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3.toShort) < smallIntAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1.toShort) >= smallIntAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4.toShort) <= smallIntAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(smallIntAttr < 4.toShort), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(smallIntAttr < 2.toShort || smallIntAttr > 3.toShort, + classOf[Operators.Or], Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } test("filter pushdown - integer") { - withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i))))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + val data = (1 to 4).map(i => Tuple1(Option(i))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => + withParquetDataFrame(df) { implicit df => + val intAttr = df(pushDownColName).expr + assert(df(pushDownColName).expr.dataType === IntegerType) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate(intAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(intAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(intAttr === 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(intAttr <=> 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(intAttr =!= 1, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(intAttr < 2, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(intAttr > 3, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(intAttr <= 1, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(intAttr >= 4, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1) === intAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1) <=> intAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2) > intAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3) < intAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1) >= intAttr, classOf[LtEq[_]], resultFun(1)) + 
checkFilterPredicate(Literal(4) <= intAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(intAttr < 4), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(intAttr < 2 || intAttr > 3, classOf[Operators.Or], + Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } test("filter pushdown - long") { - withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toLong))))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + val data = (1 to 4).map(i => Tuple1(Option(i.toLong))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => + withParquetDataFrame(df) { implicit df => + val longAttr = df(pushDownColName).expr + assert(df(pushDownColName).expr.dataType === LongType) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate(longAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(longAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(longAttr === 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(longAttr <=> 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(longAttr =!= 1, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(longAttr < 2, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(longAttr > 3, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(longAttr <= 1, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(longAttr >= 4, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1) === longAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1) <=> longAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2) > longAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3) < longAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1) >= longAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4) <= longAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(longAttr < 4), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(longAttr < 2 || longAttr > 3, classOf[Operators.Or], + Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } test("filter pushdown - float") { - withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toFloat))))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 =!= 1, 
classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + val data = (1 to 4).map(i => Tuple1(Option(i.toFloat))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => + withParquetDataFrame(df) { implicit df => + val floatAttr = df(pushDownColName).expr + assert(df(pushDownColName).expr.dataType === FloatType) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate(floatAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(floatAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(floatAttr === 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(floatAttr <=> 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(floatAttr =!= 1, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(floatAttr < 2, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(floatAttr > 3, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(floatAttr <= 1, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(floatAttr >= 4, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1) === floatAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1) <=> floatAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2) > floatAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3) < floatAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1) >= floatAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4) <= floatAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(floatAttr < 4), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(floatAttr < 2 || floatAttr > 3, classOf[Operators.Or], + Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } test("filter pushdown - double") { - withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toDouble))))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + val data = (1 to 4).map(i => Tuple1(Option(i.toDouble))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => + withParquetDataFrame(df) { implicit df => + val doubleAttr = df(pushDownColName).expr + assert(df(pushDownColName).expr.dataType === DoubleType) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate(doubleAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + 
checkFilterPredicate(doubleAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(doubleAttr === 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(doubleAttr <=> 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(doubleAttr =!= 1, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(doubleAttr < 2, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(doubleAttr > 3, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(doubleAttr <= 1, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(doubleAttr >= 4, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1) === doubleAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1) <=> doubleAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2) > doubleAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3) < doubleAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1) >= doubleAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4) <= doubleAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(doubleAttr < 4), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(doubleAttr < 2 || doubleAttr > 3, classOf[Operators.Or], + Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } test("filter pushdown - string") { - withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(i.toString)))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") - checkFilterPredicate('_1 <=> "1", classOf[Eq[_]], "1") - checkFilterPredicate( - '_1 =!= "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") - checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") - checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") - - checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") - checkFilterPredicate(Literal("1") <=> '_1, classOf[Eq[_]], "1") - checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") - checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") - checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") - checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") - - checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") - checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + val data = (1 to 4).map(i => Tuple1(Option(i.toString))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => + withParquetDataFrame(df) { implicit df => + val stringAttr = df(pushDownColName).expr + assert(df(pushDownColName).expr.dataType === StringType) + + checkFilterPredicate(stringAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + 
checkFilterPredicate(stringAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i.toString)))) + + checkFilterPredicate(stringAttr === "1", classOf[Eq[_]], resultFun("1")) + checkFilterPredicate(stringAttr <=> "1", classOf[Eq[_]], resultFun("1")) + checkFilterPredicate(stringAttr =!= "1", classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i.toString)))) + + checkFilterPredicate(stringAttr < "2", classOf[Lt[_]], resultFun("1")) + checkFilterPredicate(stringAttr > "3", classOf[Gt[_]], resultFun("4")) + checkFilterPredicate(stringAttr <= "1", classOf[LtEq[_]], resultFun("1")) + checkFilterPredicate(stringAttr >= "4", classOf[GtEq[_]], resultFun("4")) + + checkFilterPredicate(Literal("1") === stringAttr, classOf[Eq[_]], resultFun("1")) + checkFilterPredicate(Literal("1") <=> stringAttr, classOf[Eq[_]], resultFun("1")) + checkFilterPredicate(Literal("2") > stringAttr, classOf[Lt[_]], resultFun("1")) + checkFilterPredicate(Literal("3") < stringAttr, classOf[Gt[_]], resultFun("4")) + checkFilterPredicate(Literal("1") >= stringAttr, classOf[LtEq[_]], resultFun("1")) + checkFilterPredicate(Literal("4") <= stringAttr, classOf[GtEq[_]], resultFun("4")) + + checkFilterPredicate(!(stringAttr < "4"), classOf[GtEq[_]], resultFun("4")) + checkFilterPredicate(stringAttr < "2" || stringAttr > "3", classOf[Operators.Or], + Seq(Row(resultFun("1")), Row(resultFun("4")))) + } } } From 4883e68f636b5e2b9b845f4e23909a2854d28063 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 25 Mar 2020 21:36:55 -0700 Subject: [PATCH 7/9] Address feedback --- .../parquet/ParquetFilterSuite.scala | 256 +++++++++--------- 1 file changed, 135 insertions(+), 121 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 0d003b9026c05..9a4368ed92f97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -21,6 +21,9 @@ import java.math.{BigDecimal => JBigDecimal} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} @@ -42,6 +45,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} + /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. 
* @@ -103,30 +107,14 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } - private def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit df: DataFrame): Unit = { - def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = { - assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted) { - df.rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted - } - } - - checkFilterPredicate(df, predicate, filterClass, checkBinaryAnswer _, expected) - } - - private def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) - (implicit df: DataFrame): Unit = { - checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) - } - /** * Takes single level `inputDF` dataframe to generate multi-level nested * dataframes as new test data. */ private def withNestedDataFrame(inputDF: DataFrame) - (testCases: (DataFrame, String, Any => Any) => Unit): Unit = { + (runTests: (DataFrame, String, Any => Any) => Unit): Unit = { + assert(inputDF.schema.fields.length == 1) + assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType]) val df = inputDF.toDF("temp") Seq( ( @@ -153,7 +141,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared (x: Any) => Row(x) ) ).foreach { case (df, pushDownColName, resultTransFun) => - testCases(df, pushDownColName, resultTransFun) + runTests(df, pushDownColName, resultTransFun) } } @@ -165,8 +153,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared val ts4 = data(3) import testImplicits._ - withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => - withParquetDataFrame(df) { implicit df => + withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => val tsAttr = df(pushDownColName).expr assert(df(pushDownColName).expr.dataType === TimestampType) @@ -227,8 +215,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("filter pushdown - boolean") { val data = true :: false :: Nil import testImplicits._ - withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => - withParquetDataFrame(df) { implicit df => + withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => val booleanAttr = df(pushDownColName).expr assert(df(pushDownColName).expr.dataType === BooleanType) @@ -246,8 +234,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("filter pushdown - tinyint") { val data = (1 to 4).map(i => Tuple1(Option(i.toByte))) import testImplicits._ - withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => - withParquetDataFrame(df) { implicit df => + withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => val tinyIntAttr = df(pushDownColName).expr assert(df(pushDownColName).expr.dataType === ByteType) @@ -282,8 +270,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("filter pushdown - smallint") { val data = (1 to 4).map(i => Tuple1(Option(i.toShort))) import testImplicits._ - withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => - 
withParquetDataFrame(df) { implicit df => + withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => val smallIntAttr = df(pushDownColName).expr assert(df(pushDownColName).expr.dataType === ShortType) @@ -318,8 +306,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("filter pushdown - integer") { val data = (1 to 4).map(i => Tuple1(Option(i))) import testImplicits._ - withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => - withParquetDataFrame(df) { implicit df => + withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => val intAttr = df(pushDownColName).expr assert(df(pushDownColName).expr.dataType === IntegerType) @@ -354,8 +342,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("filter pushdown - long") { val data = (1 to 4).map(i => Tuple1(Option(i.toLong))) import testImplicits._ - withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => - withParquetDataFrame(df) { implicit df => + withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => val longAttr = df(pushDownColName).expr assert(df(pushDownColName).expr.dataType === LongType) @@ -390,8 +378,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("filter pushdown - float") { val data = (1 to 4).map(i => Tuple1(Option(i.toFloat))) import testImplicits._ - withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => - withParquetDataFrame(df) { implicit df => + withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => val floatAttr = df(pushDownColName).expr assert(df(pushDownColName).expr.dataType === FloatType) @@ -426,8 +414,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("filter pushdown - double") { val data = (1 to 4).map(i => Tuple1(Option(i.toDouble))) import testImplicits._ - withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => - withParquetDataFrame(df) { implicit df => + withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => val doubleAttr = df(pushDownColName).expr assert(df(pushDownColName).expr.dataType === DoubleType) @@ -462,8 +450,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("filter pushdown - string") { val data = (1 to 4).map(i => Tuple1(Option(i.toString))) import testImplicits._ - withNestedDataFrame(data.toDF()) { case (df, pushDownColName, resultFun) => - withParquetDataFrame(df) { implicit df => + withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => val stringAttr = df(pushDownColName).expr assert(df(pushDownColName).expr.dataType === StringType) @@ -500,32 +488,39 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) } - withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(i.b)))) { implicit df => - checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate('_1 <=> 1.b, classOf[Eq[_]], 1.b) - - checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - 
checkBinaryFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) - - checkBinaryFilterPredicate( - '_1 =!= 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) - - checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - - checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate(Literal(1.b) <=> '_1, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) - - checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate( - '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) + val data = (1 to 4).map(i => Tuple1(Option(i.b))) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val binaryAttr: Expression = df(pushDownColName).expr + assert(df(pushDownColName).expr.dataType === BinaryType) + + checkFilterPredicate(binaryAttr === 1.b, classOf[Eq[_]], resultFun(1.b)) + checkFilterPredicate(binaryAttr <=> 1.b, classOf[Eq[_]], resultFun(1.b)) + + checkFilterPredicate(binaryAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(binaryAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i.b)))) + + checkFilterPredicate(binaryAttr =!= 1.b, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i.b)))) + + checkFilterPredicate(binaryAttr < 2.b, classOf[Lt[_]], resultFun(1.b)) + checkFilterPredicate(binaryAttr > 3.b, classOf[Gt[_]], resultFun(4.b)) + checkFilterPredicate(binaryAttr <= 1.b, classOf[LtEq[_]], resultFun(1.b)) + checkFilterPredicate(binaryAttr >= 4.b, classOf[GtEq[_]], resultFun(4.b)) + + checkFilterPredicate(Literal(1.b) === binaryAttr, classOf[Eq[_]], resultFun(1.b)) + checkFilterPredicate(Literal(1.b) <=> binaryAttr, classOf[Eq[_]], resultFun(1.b)) + checkFilterPredicate(Literal(2.b) > binaryAttr, classOf[Lt[_]], resultFun(1.b)) + checkFilterPredicate(Literal(3.b) < binaryAttr, classOf[Gt[_]], resultFun(4.b)) + checkFilterPredicate(Literal(1.b) >= binaryAttr, classOf[LtEq[_]], resultFun(1.b)) + checkFilterPredicate(Literal(4.b) <= binaryAttr, classOf[GtEq[_]], resultFun(4.b)) + + checkFilterPredicate(!(binaryAttr < 4.b), classOf[GtEq[_]], resultFun(4.b)) + checkFilterPredicate(binaryAttr < 2.b || binaryAttr > 3.b, classOf[Operators.Or], + Seq(Row(resultFun(1.b)), Row(resultFun(4.b)))) + } } } @@ -534,40 +529,53 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared def date: Date = Date.valueOf(s) } - val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21") - - withParquetDataFrame(toDF(data.map(i => Tuple1(i.date)))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i.date))) - - checkFilterPredicate('_1 === "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) - checkFilterPredicate('_1 <=> "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) - checkFilterPredicate('_1 =!= "2018-03-18".date, 
classOf[NotEq[_]], - Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(i.date))) - - checkFilterPredicate('_1 < "2018-03-19".date, classOf[Lt[_]], "2018-03-18".date) - checkFilterPredicate('_1 > "2018-03-20".date, classOf[Gt[_]], "2018-03-21".date) - checkFilterPredicate('_1 <= "2018-03-18".date, classOf[LtEq[_]], "2018-03-18".date) - checkFilterPredicate('_1 >= "2018-03-21".date, classOf[GtEq[_]], "2018-03-21".date) + val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21").map(_.date) + import testImplicits._ + withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val dateAttr: Expression = df(pushDownColName).expr + assert(df(pushDownColName).expr.dataType === DateType) - checkFilterPredicate( - Literal("2018-03-18".date) === '_1, classOf[Eq[_]], "2018-03-18".date) - checkFilterPredicate( - Literal("2018-03-18".date) <=> '_1, classOf[Eq[_]], "2018-03-18".date) - checkFilterPredicate( - Literal("2018-03-19".date) > '_1, classOf[Lt[_]], "2018-03-18".date) - checkFilterPredicate( - Literal("2018-03-20".date) < '_1, classOf[Gt[_]], "2018-03-21".date) - checkFilterPredicate( - Literal("2018-03-18".date) >= '_1, classOf[LtEq[_]], "2018-03-18".date) - checkFilterPredicate( - Literal("2018-03-21".date) <= '_1, classOf[GtEq[_]], "2018-03-21".date) + checkFilterPredicate(dateAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(dateAttr.isNotNull, classOf[NotEq[_]], + data.map(i => Row.apply(resultFun(i)))) - checkFilterPredicate(!('_1 < "2018-03-21".date), classOf[GtEq[_]], "2018-03-21".date) - checkFilterPredicate( - '_1 < "2018-03-19".date || '_1 > "2018-03-20".date, - classOf[Operators.Or], - Seq(Row("2018-03-18".date), Row("2018-03-21".date))) + checkFilterPredicate(dateAttr === "2018-03-18".date, classOf[Eq[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(dateAttr <=> "2018-03-18".date, classOf[Eq[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(dateAttr =!= "2018-03-18".date, classOf[NotEq[_]], + Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(resultFun(i.date)))) + + checkFilterPredicate(dateAttr < "2018-03-19".date, classOf[Lt[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(dateAttr > "2018-03-20".date, classOf[Gt[_]], + resultFun("2018-03-21".date)) + checkFilterPredicate(dateAttr <= "2018-03-18".date, classOf[LtEq[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(dateAttr >= "2018-03-21".date, classOf[GtEq[_]], + resultFun("2018-03-21".date)) + + checkFilterPredicate(Literal("2018-03-18".date) === dateAttr, classOf[Eq[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(Literal("2018-03-18".date) <=> dateAttr, classOf[Eq[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(Literal("2018-03-19".date) > dateAttr, classOf[Lt[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(Literal("2018-03-20".date) < dateAttr, classOf[Gt[_]], + resultFun("2018-03-21".date)) + checkFilterPredicate(Literal("2018-03-18".date) >= dateAttr, classOf[LtEq[_]], + resultFun("2018-03-18".date)) + checkFilterPredicate(Literal("2018-03-21".date) <= dateAttr, classOf[GtEq[_]], + resultFun("2018-03-21".date)) + + checkFilterPredicate(!(dateAttr < "2018-03-21".date), classOf[GtEq[_]], + resultFun("2018-03-21".date)) + checkFilterPredicate( + dateAttr < "2018-03-19".date || dateAttr > "2018-03-20".date, + classOf[Operators.Or], + Seq(Row(resultFun("2018-03-18".date)), Row(resultFun("2018-03-21".date)))) + } } } @@ 
-612,33 +620,39 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared (false, DecimalType.MAX_PRECISION) // binaryWriterUsingUnscaledBytes ).foreach { case (legacyFormat, precision) => withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> legacyFormat.toString) { - val schema = StructType.fromDDL(s"a decimal($precision, 2)") val rdd = spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i)))) - val dataFrame = spark.createDataFrame(rdd, schema) - withParquetDataFrame(dataFrame) { implicit df => - assert(df.schema === schema) - checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('a === 1, classOf[Eq[_]], 1) - checkFilterPredicate('a <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('a =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('a < 2, classOf[Lt[_]], 1) - checkFilterPredicate('a > 3, classOf[Gt[_]], 4) - checkFilterPredicate('a <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('a >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === 'a, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1) <=> 'a, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > 'a, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < 'a, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= 'a, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= 'a, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('a < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('a < 2 || 'a > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + val dataFrame = spark.createDataFrame(rdd, StructType.fromDDL(s"a decimal($precision, 2)")) + withNestedDataFrame(dataFrame) { case (inputDF, pushDownColName, resultFun) => + withParquetDataFrame(inputDF) { implicit df => + val decimalAttr: Expression = df(pushDownColName).expr + assert(df(pushDownColName).expr.dataType === DecimalType(precision, 2)) + + checkFilterPredicate(decimalAttr.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(decimalAttr.isNotNull, classOf[NotEq[_]], + (1 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(decimalAttr === 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(decimalAttr <=> 1, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(decimalAttr =!= 1, classOf[NotEq[_]], + (2 to 4).map(i => Row.apply(resultFun(i)))) + + checkFilterPredicate(decimalAttr < 2, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(decimalAttr > 3, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(decimalAttr <= 1, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(decimalAttr >= 4, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(Literal(1) === decimalAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(1) <=> decimalAttr, classOf[Eq[_]], resultFun(1)) + checkFilterPredicate(Literal(2) > decimalAttr, classOf[Lt[_]], resultFun(1)) + checkFilterPredicate(Literal(3) < decimalAttr, classOf[Gt[_]], resultFun(4)) + checkFilterPredicate(Literal(1) >= decimalAttr, classOf[LtEq[_]], resultFun(1)) + checkFilterPredicate(Literal(4) <= decimalAttr, classOf[GtEq[_]], resultFun(4)) + + checkFilterPredicate(!(decimalAttr < 4), classOf[GtEq[_]], resultFun(4)) + checkFilterPredicate(decimalAttr < 2 || decimalAttr > 3, classOf[Operators.Or], + Seq(Row(resultFun(1)), Row(resultFun(4)))) + } } } } From f732664630c6465ef8d4ec58dbc3965fcee82f3b Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 25 Mar 2020 23:04:24 -0700 Subject: 
[PATCH 8/9] clean up imports and unused functions --- .../datasources/parquet/ParquetFilters.scala | 1 + .../parquet/ParquetFilterSuite.scala | 100 +++++++++--------- .../datasources/parquet/ParquetTest.scala | 4 - 3 files changed, 53 insertions(+), 52 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 65731fbb2590d..f206f59dacdc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -183,6 +183,7 @@ class ParquetFilters( (n: Array[String], v: Any) => FilterApi.eq( longColumn(n), Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => (n: Array[String], v: Any) => FilterApi.eq( intColumn(n), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 9a4368ed92f97..d1161e33b0941 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -21,9 +21,6 @@ import java.math.{BigDecimal => JBigDecimal} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag - import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} @@ -45,7 +42,6 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} - /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. * @@ -112,7 +108,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared * dataframes as new test data. 
    */
   private def withNestedDataFrame(inputDF: DataFrame)
-      (runTests: (DataFrame, String, Any => Any) => Unit): Unit = {
+      (runTest: (DataFrame, String, Any => Any) => Unit): Unit = {
     assert(inputDF.schema.fields.length == 1)
     assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType])
     val df = inputDF.toDF("temp")
@@ -140,8 +136,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
         "`a.b`.`c.d`", // one level nesting with column names containing `dots`
         (x: Any) => Row(x)
       )
-    ).foreach { case (df, pushDownColName, resultTransFun) =>
-      runTests(df, pushDownColName, resultTransFun)
+    ).foreach { case (df, colName, resultFun) =>
+      runTest(df, colName, resultFun)
     }
   }
 
@@ -153,10 +149,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
     val ts4 = data(3)
 
     import testImplicits._
-    withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) =>
+    withNestedDataFrame(data.map(i => Tuple1(i)).toDF()) { case (inputDF, colName, resultFun) =>
       withParquetDataFrame(inputDF) { implicit df =>
-        val tsAttr = df(pushDownColName).expr
-        assert(df(pushDownColName).expr.dataType === TimestampType)
+        val tsAttr = df(colName).expr
+        assert(df(colName).expr.dataType === TimestampType)
 
         checkFilterPredicate(tsAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
         checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]],
@@ -213,12 +209,12 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }
 
   test("filter pushdown - boolean") {
-    val data = true :: false :: Nil
+    val data = (true :: false :: Nil).map(b => Tuple1.apply(Option(b)))
     import testImplicits._
-    withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) =>
+    withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
       withParquetDataFrame(inputDF) { implicit df =>
-        val booleanAttr = df(pushDownColName).expr
-        assert(df(pushDownColName).expr.dataType === BooleanType)
+        val booleanAttr = df(colName).expr
+        assert(df(colName).expr.dataType === BooleanType)
 
         checkFilterPredicate(booleanAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
         checkFilterPredicate(booleanAttr.isNotNull, classOf[NotEq[_]],
@@ -234,10 +230,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   test("filter pushdown - tinyint") {
     val data = (1 to 4).map(i => Tuple1(Option(i.toByte)))
     import testImplicits._
-    withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) =>
+    withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
       withParquetDataFrame(inputDF) { implicit df =>
-        val tinyIntAttr = df(pushDownColName).expr
-        assert(df(pushDownColName).expr.dataType === ByteType)
+        val tinyIntAttr = df(colName).expr
+        assert(df(colName).expr.dataType === ByteType)
 
         checkFilterPredicate(tinyIntAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
         checkFilterPredicate(tinyIntAttr.isNotNull, classOf[NotEq[_]],
@@ -270,10 +266,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   test("filter pushdown - smallint") {
     val data = (1 to 4).map(i => Tuple1(Option(i.toShort)))
     import testImplicits._
-    withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) =>
+    withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
       withParquetDataFrame(inputDF) { implicit df =>
-        val smallIntAttr = df(pushDownColName).expr
-        assert(df(pushDownColName).expr.dataType === ShortType)
+        val smallIntAttr = df(colName).expr
+        assert(df(colName).expr.dataType === ShortType)
 
         checkFilterPredicate(smallIntAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
         checkFilterPredicate(smallIntAttr.isNotNull, classOf[NotEq[_]],
@@ -306,10 +302,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   test("filter pushdown - integer") {
     val data = (1 to 4).map(i => Tuple1(Option(i)))
     import testImplicits._
-    withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) =>
+    withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
       withParquetDataFrame(inputDF) { implicit df =>
-        val intAttr = df(pushDownColName).expr
-        assert(df(pushDownColName).expr.dataType === IntegerType)
+        val intAttr = df(colName).expr
+        assert(df(colName).expr.dataType === IntegerType)
 
         checkFilterPredicate(intAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
         checkFilterPredicate(intAttr.isNotNull, classOf[NotEq[_]],
@@ -342,10 +338,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   test("filter pushdown - long") {
     val data = (1 to 4).map(i => Tuple1(Option(i.toLong)))
     import testImplicits._
-    withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) =>
+    withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
       withParquetDataFrame(inputDF) { implicit df =>
-        val longAttr = df(pushDownColName).expr
-        assert(df(pushDownColName).expr.dataType === LongType)
+        val longAttr = df(colName).expr
+        assert(df(colName).expr.dataType === LongType)
 
         checkFilterPredicate(longAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
         checkFilterPredicate(longAttr.isNotNull, classOf[NotEq[_]],
@@ -378,10 +374,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   test("filter pushdown - float") {
     val data = (1 to 4).map(i => Tuple1(Option(i.toFloat)))
     import testImplicits._
-    withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) =>
+    withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
       withParquetDataFrame(inputDF) { implicit df =>
-        val floatAttr = df(pushDownColName).expr
-        assert(df(pushDownColName).expr.dataType === FloatType)
+        val floatAttr = df(colName).expr
+        assert(df(colName).expr.dataType === FloatType)
 
         checkFilterPredicate(floatAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
         checkFilterPredicate(floatAttr.isNotNull, classOf[NotEq[_]],
@@ -414,10 +410,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   test("filter pushdown - double") {
     val data = (1 to 4).map(i => Tuple1(Option(i.toDouble)))
     import testImplicits._
-    withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) =>
+    withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
       withParquetDataFrame(inputDF) { implicit df =>
-        val doubleAttr = df(pushDownColName).expr
-        assert(df(pushDownColName).expr.dataType === DoubleType)
+        val doubleAttr = df(colName).expr
+        assert(df(colName).expr.dataType === DoubleType)
 
         checkFilterPredicate(doubleAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
         checkFilterPredicate(doubleAttr.isNotNull, classOf[NotEq[_]],
@@ -450,10 +446,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   test("filter pushdown - string") {
     val data = (1 to 4).map(i => Tuple1(Option(i.toString)))
     import testImplicits._
-    withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) =>
+    withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
       withParquetDataFrame(inputDF) { implicit df =>
-        val stringAttr = df(pushDownColName).expr
-        assert(df(pushDownColName).expr.dataType === StringType)
+        val stringAttr = df(colName).expr
+        assert(df(colName).expr.dataType === StringType)
 
         checkFilterPredicate(stringAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
         checkFilterPredicate(stringAttr.isNotNull, classOf[NotEq[_]],
@@ -490,10 +486,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
 
     val data = (1 to 4).map(i => Tuple1(Option(i.b)))
     import testImplicits._
-    withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) =>
+    withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
       withParquetDataFrame(inputDF) { implicit df =>
-        val binaryAttr: Expression = df(pushDownColName).expr
-        assert(df(pushDownColName).expr.dataType === BinaryType)
+        val binaryAttr: Expression = df(colName).expr
+        assert(df(colName).expr.dataType === BinaryType)
 
         checkFilterPredicate(binaryAttr === 1.b, classOf[Eq[_]], resultFun(1.b))
 
         checkFilterPredicate(binaryAttr <=> 1.b, classOf[Eq[_]], resultFun(1.b))
@@ -531,10 +527,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
     val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21").map(_.date)
 
     import testImplicits._
-    withNestedDataFrame(data.toDF()) { case (inputDF, pushDownColName, resultFun) =>
+    withNestedDataFrame(data.map(i => Tuple1(i)).toDF()) { case (inputDF, colName, resultFun) =>
       withParquetDataFrame(inputDF) { implicit df =>
-        val dateAttr: Expression = df(pushDownColName).expr
-        assert(df(pushDownColName).expr.dataType === DateType)
+        val dateAttr: Expression = df(colName).expr
+        assert(df(colName).expr.dataType === DateType)
 
         checkFilterPredicate(dateAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
         checkFilterPredicate(dateAttr.isNotNull, classOf[NotEq[_]],
@@ -603,7 +599,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
     // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown
     withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key ->
       ParquetOutputTimestampType.INT96.toString) {
-      withParquetDataFrame(toDF(millisData.map(i => Tuple1(i)))) { implicit df =>
+      import testImplicits._
+      withParquetDataFrame(millisData.map(i => Tuple1(i)).toDF()) { implicit df =>
         val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema)
         assertResult(None) {
           createParquetFilters(schema).createFilter(sources.IsNull("_1"))
@@ -623,10 +620,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
       val rdd =
         spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i))))
       val dataFrame = spark.createDataFrame(rdd, StructType.fromDDL(s"a decimal($precision, 2)"))
-      withNestedDataFrame(dataFrame) { case (inputDF, pushDownColName, resultFun) =>
+      withNestedDataFrame(dataFrame) { case (inputDF, colName, resultFun) =>
         withParquetDataFrame(inputDF) { implicit df =>
-          val decimalAttr: Expression = df(pushDownColName).expr
-          assert(df(pushDownColName).expr.dataType === DecimalType(precision, 2))
+          val decimalAttr: Expression = df(colName).expr
+          assert(df(colName).expr.dataType === DecimalType(precision, 2))
 
           checkFilterPredicate(decimalAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
           checkFilterPredicate(decimalAttr.isNotNull, classOf[NotEq[_]],
@@ -1166,7 +1163,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }
 
   test("SPARK-16371 Do not push down filters when inner name and outer name are the same") {
-    withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Tuple1(i))))) { implicit df =>
+    import testImplicits._
+    withParquetDataFrame((1 to 4).map(i => Tuple1(Tuple1(i))).toDF()) { implicit df =>
       // Here the schema becomes as below:
       //
       // root
@@ -1308,7 +1306,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }
 
   test("filter pushdown - StringStartsWith") {
-    withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(i + "str" + i)))) { implicit df =>
+    withParquetDataFrame {
+      import testImplicits._
+      (1 to 4).map(i => Tuple1(i + "str" + i)).toDF()
+    } { implicit df =>
       checkFilterPredicate(
         '_1.startsWith("").asInstanceOf[Predicate],
         classOf[UserDefinedByInstance[_, _]],
@@ -1354,7 +1355,10 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
     }
 
     // SPARK-28371: make sure filter is null-safe.
-    withParquetDataFrame(toDF(Seq(Tuple1[String](null)))) { implicit df =>
+    withParquetDataFrame {
+      import testImplicits._
+      Seq(Tuple1[String](null)).toDF()
+    } { implicit df =>
       checkFilterPredicate(
         '_1.startsWith("blah").asInstanceOf[Predicate],
         classOf[UserDefinedByInstance[_, _]],
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index 6e5c562571997..f2dbc536ac566 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -62,10 +62,6 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest {
       (data: Seq[T])
       (f: String => Unit): Unit = withDataSourceFile(data)(f)
 
-  protected def toDF[T <: Product: ClassTag: TypeTag](data: Seq[T]): DataFrame = {
-    spark.createDataFrame(data)
-  }
-
   /**
    * Writes `df` dataframe to a Parquet file and reads it back as a [[DataFrame]],
    * which is then passed to `f`. The Parquet file will be deleted after `f` returns.

From 5fd97c0a90eb1885a93fffb9d04a262b35f62bc3 Mon Sep 17 00:00:00 2001
From: DB Tsai
Date: Thu, 26 Mar 2020 00:59:38 -0700
Subject: [PATCH 9/9] Trigger Build