diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala index 144e9ad129feb..d0f38c12427c3 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala @@ -62,10 +62,6 @@ case class AvroScan( pushedFilters) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case a: AvroScan => super.equals(a) && dataSchema == a.dataSchema && options == a.options && equivalentFilters(pushedFilters, a.pushedFilters) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala index 9420608bb22ce..8fae89a945826 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.v2.avro import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter @@ -31,7 +31,7 @@ class AvroScanBuilder ( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { AvroScan( @@ -41,17 +41,16 @@ class AvroScanBuilder ( readDataSchema(), readPartitionSchema(), options, - pushedFilters()) + pushedDataFilters, + partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.avroFilterPushDown) { - _pushedFilters = StructFilters.pushedFilters(filters, dataSchema) + StructFilters.pushedFilters(dataFilters, dataSchema) + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala new file mode 100644 index 0000000000000..9c2a4ac78a24a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.internal.connector + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.sources.Filter + +/** + * A mix-in interface for {@link FileScanBuilder}. File sources can implement this interface to + * push down filters to the file source. The pushed down filters will be separated into partition + * filters and data filters. Partition filters are used for partition pruning and data filters are + * used to reduce the size of the data to be read. + */ +trait SupportsPushDownCatalystFilters { + + /** + * Pushes down catalyst Expression filters (which will be separated into partition filters and + * data filters), and returns data filters that need to be evaluated after scanning. + */ + def pushFilters(filters: Seq[Expression]): Seq[Expression] + + /** + * Returns the data filters that are pushed to the data source via + * {@link #pushFilters(Expression[])}. + */ + def pushedFilters: Array[Filter] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index fcd95a27bf8ca..67d03998a2a24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -28,6 +28,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.SparkUpgradeException import org.apache.spark.sql.{SPARK_LEGACY_DATETIME, SPARK_LEGACY_INT96, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper} import org.apache.spark.sql.catalyst.util.RebaseDateTime import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions @@ -39,7 +40,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils -object DataSourceUtils { +object DataSourceUtils extends PredicateHelper { /** * The key to use for storing partitionBy columns as options. 
*/ @@ -242,4 +243,22 @@ object DataSourceUtils { options } } + + def getPartitionFiltersAndDataFilters( + partitionSchema: StructType, + normalizedFilters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + val partitionColumns = normalizedFilters.flatMap { expr => + expr.collect { + case attr: AttributeReference if partitionSchema.names.contains(attr.name) => + attr + } + } + val partitionSet = AttributeSet(partitionColumns) + val (partitionFilters, dataFilters) = normalizedFilters.partition(f => + f.references.subsetOf(partitionSet) + ) + val extraPartitionFilter = + dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet)) + (ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq, dataFilters) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 0927027bee0bc..2e8e5426d47be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -17,52 +17,24 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogStatistics import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan} -import org.apache.spark.sql.types.StructType /** * Prune the partitions of file source based table using partition filters. Currently, this rule - * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]] and [[DataSourceV2ScanRelation]] - * with [[FileScan]]. + * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]]. * * For [[HadoopFsRelation]], the location will be replaced by pruned file index, and corresponding * statistics will be updated. And the partition filters will be kept in the filters of returned * logical plan. - * - * For [[DataSourceV2ScanRelation]], both partition filters and data filters will be added to - * its underlying [[FileScan]]. And the partition filters will be removed in the filters of - * returned logical plan. 
*/ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] with PredicateHelper { - private def getPartitionKeyFiltersAndDataFilters( - sparkSession: SparkSession, - relation: LeafNode, - partitionSchema: StructType, - filters: Seq[Expression], - output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = { - val normalizedFilters = DataSourceStrategy.normalizeExprs( - filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output) - val partitionColumns = - relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver) - val partitionSet = AttributeSet(partitionColumns) - val (partitionFilters, dataFilters) = normalizedFilters.partition(f => - f.references.subsetOf(partitionSet) - ) - val extraPartitionFilter = - dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet)) - - (ExpressionSet(partitionFilters ++ extraPartitionFilter), dataFilters) - } - private def rebuildPhysicalOperation( projects: Seq[NamedExpression], filters: Seq[Expression], @@ -91,12 +63,14 @@ private[sql] object PruneFileSourcePartitions _, _)) if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => - val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters( - fsRelation.sparkSession, logicalRelation, partitionSchema, filters, + val normalizedFilters = DataSourceStrategy.normalizeExprs( + filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), logicalRelation.output) + val (partitionKeyFilters, _) = DataSourceUtils + .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters) if (partitionKeyFilters.nonEmpty) { - val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) + val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters) val prunedFsRelation = fsRelation.copy(location = prunedFileIndex)(fsRelation.sparkSession) // Change table stats based on the sizeInBytes of pruned files @@ -117,23 +91,5 @@ private[sql] object PruneFileSourcePartitions } else { op } - - case op @ PhysicalOperation(projects, filters, - v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output)) - if filters.nonEmpty => - val (partitionKeyFilters, dataFilters) = - getPartitionKeyFiltersAndDataFilters(scan.sparkSession, v2Relation, - scan.readPartitionSchema, filters, output) - // The dataFilters are pushed down only once - if (partitionKeyFilters.nonEmpty || (dataFilters.nonEmpty && scan.dataFilters.isEmpty)) { - val prunedV2Relation = - v2Relation.copy(scan = scan.withFilters(partitionKeyFilters.toSeq, dataFilters)) - // The pushed down partition filters don't need to be reevaluated. 
- val afterScanFilters = - ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty) - rebuildPhysicalOperation(projects, afterScanFilters.toSeq, prunedV2Relation) - } else { - op - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index b20270275d9fa..8b0328cabc5a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -71,12 +71,6 @@ trait FileScan extends Scan */ def dataFilters: Seq[Expression] - /** - * Create a new `FileScan` instance from the current one - * with different `partitionFilters` and `dataFilters` - */ - def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan - /** * If a file with `path` is unsplittable, return the unsplittable reason, * otherwise return `None`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 97874e8f4932e..309f045201140 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -16,19 +16,30 @@ */ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.SparkSession +import scala.collection.mutable + +import org.apache.spark.sql.{sources, SparkSession} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils} +import org.apache.spark.sql.internal.connector.SupportsPushDownCatalystFilters +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType abstract class FileScanBuilder( sparkSession: SparkSession, fileIndex: PartitioningAwareFileIndex, - dataSchema: StructType) extends ScanBuilder with SupportsPushDownRequiredColumns { + dataSchema: StructType) + extends ScanBuilder + with SupportsPushDownRequiredColumns + with SupportsPushDownCatalystFilters { private val partitionSchema = fileIndex.partitionSchema private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis protected val supportsNestedSchemaPruning = false protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields) + protected var partitionFilters = Seq.empty[Expression] + protected var dataFilters = Seq.empty[Expression] + protected var pushedDataFilters = Array.empty[Filter] override def pruneColumns(requiredSchema: StructType): Unit = { // [SPARK-30107] While `requiredSchema` might have pruned nested columns, @@ -48,7 +59,7 @@ abstract class FileScanBuilder( StructType(fields) } - protected def readPartitionSchema(): StructType = { + def readPartitionSchema(): StructType = { val requiredNameSet = createRequiredNameSet() val fields = partitionSchema.fields.filter { field => val colName = PartitioningUtils.getColName(field, isCaseSensitive) @@ -57,6 +68,31 @@ abstract class FileScanBuilder( StructType(fields) } + override def 
pushFilters(filters: Seq[Expression]): Seq[Expression] = { + val (partitionFilters, dataFilters) = + DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filters) + this.partitionFilters = partitionFilters + this.dataFilters = dataFilters + val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter] + for (filterExpr <- dataFilters) { + val translated = DataSourceStrategy.translateFilter(filterExpr, true) + if (translated.nonEmpty) { + translatedFilters += translated.get + } + } + pushedDataFilters = pushDataFilters(translatedFilters.toArray) + dataFilters + } + + override def pushedFilters: Array[Filter] = pushedDataFilters + + /* + * Push down data filters to the file source, so the data filters can be evaluated there to + * reduce the size of the data to be read. By default, data filters are not pushed down. + * File source needs to implement this method to push down data filters. + */ + protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = Array.empty[Filter] + private def createRequiredNameSet(): Set[String] = requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index acc645741819e..7229488026bc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -25,9 +25,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType @@ -71,6 +69,9 @@ object PushDownUtils extends PredicateHelper { } (r.pushedFilters(), (untranslatableExprs ++ postScanFilters).toSeq) + case f: FileScanBuilder => + val postScanFilters = f.pushFilters(filters) + (f.pushedFilters, postScanFilters) case _ => (Nil, filters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index 3f77b2147f9ca..cc3c146106670 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.csv.CSVDataSource -import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import 
org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -84,10 +84,6 @@ case class CSVScan( dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options && equivalentFilters(pushedFilters, c.pushedFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala index f7a79bf31948e..2b6edd4f357ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2.csv import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter @@ -32,7 +32,7 @@ case class CSVScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { CSVScan( @@ -42,17 +42,16 @@ case class CSVScanBuilder( readDataSchema(), readPartitionSchema(), options, - pushedFilters()) + pushedDataFilters, + partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.csvFilterPushDown) { - _pushedFilters = StructFilters.pushedFilters(filters, dataSchema) + StructFilters.pushedFilters(dataFilters, dataSchema) + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala index 29eb8bec9a589..9ab367136fc97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.json.JsonDataSource -import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import 
org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -83,10 +83,6 @@ case class JsonScan( dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case j: JsonScan => super.equals(j) && dataSchema == j.dataSchema && options == j.options && equivalentFilters(pushedFilters, j.pushedFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala index cf1204566ddbd..c581617a4b7e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.json import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter @@ -31,7 +31,7 @@ class JsonScanBuilder ( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { JsonScan( sparkSession, @@ -40,17 +40,16 @@ class JsonScanBuilder ( readDataSchema(), readPartitionSchema(), options, - pushedFilters()) + pushedDataFilters, + partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.jsonFilterPushDown) { - _pushedFilters = StructFilters.pushedFilters(filters, dataSchema) + StructFilters.pushedFilters(dataFilters, dataSchema) + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 8fa7f8dc41ead..7619e3c503139 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -68,8 +68,4 @@ case class OrcScan( override def getMetaData(): Map[String, String] = { super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) } - - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index dc59526bb316b..cfa396f5482f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2.orc import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder @@ -35,7 +35,7 @@ case class OrcScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -45,20 +45,17 @@ case class OrcScanBuilder( override protected val supportsNestedSchemaPruning: Boolean = true override def build(): Scan = { - OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, - readDataSchema(), readPartitionSchema(), options, pushedFilters()) + OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), + readPartitionSchema(), options, pushedDataFilters, partitionFilters, dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.orcFilterPushDown) { val dataTypeMap = OrcFilters.getSearchableTypeMap( readDataSchema(), SQLConf.get.caseSensitiveAnalysis) - _pushedFilters = OrcFilters.convertibleFilters(dataTypeMap, filters).toArray + OrcFilters.convertibleFilters(dataTypeMap, dataFilters).toArray + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 60573ba10ccb6..e277e334845c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -105,8 +105,4 @@ case class ParquetScan( override def getMetaData(): Map[String, String] = { super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) } - - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 4b3f4e7edca6c..ff5137e928db3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder @@ -35,7 +35,7 @@ case class ParquetScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -63,17 +63,12 @@ case class ParquetScanBuilder( // The rebase mode doesn't matter here because the filters are used to determine // whether they is convertible. LegacyBehaviorPolicy.CORRECTED) - parquetFilters.convertibleFilters(this.filters).toArray + parquetFilters.convertibleFilters(pushedDataFilters).toArray } override protected val supportsNestedSchemaPruning: Boolean = true - private var filters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - this.filters = filters - this.filters - } + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = dataFilters // Note: for Parquet, the actual filter push down happens in [[ParquetPartitionReaderFactory]]. // It requires the Parquet physical schema to determine whether a filter is convertible. 
@@ -82,6 +77,6 @@ case class ParquetScanBuilder( override def build(): Scan = { ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, options) + readPartitionSchema(), pushedParquetFilters, options, partitionFilters, dataFilters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala index a401d296d3eaf..c7b0fec34b4e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.text.TextOptions -import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -72,10 +72,6 @@ case class TextScan( readPartitionSchema, textOptions) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case t: TextScan => super.equals(t) && options == t.options diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala index d929468b1b8b1..0ebb098bfc1df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala @@ -33,6 +33,7 @@ case class TextScanBuilder( extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { - TextScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options) + TextScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options, + partitionFilters, dataFilters) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index e5c82603d8893..f7f1d0b847cc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -3023,16 +3023,14 @@ class JsonV2Suite extends JsonSuite { withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { file => val scanBuilder = getBuilder(file.getCanonicalPath) - assert(scanBuilder.pushFilters(filters) === filters) - assert(scanBuilder.pushedFilters() === filters) + assert(scanBuilder.pushDataFilters(filters) === filters) } } withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "false") { withTempPath { file => val scanBuilder = getBuilder(file.getCanonicalPath) - assert(scanBuilder.pushFilters(filters) === filters) - 
assert(scanBuilder.pushedFilters() === Array.empty[sources.Filter])
+        assert(scanBuilder.pushDataFilters(filters) === Array.empty[sources.Filter])
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 526dad91e5e19..02f10aa0af424 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -447,7 +447,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     query.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [SUM(SALARY), SUM(BONUS)"
+          "PushedAggregates: [SUM(SALARY), SUM(BONUS)]"
         checkKeywordsExistsInExplain(query, expected_plan_fragment)
     }
     checkAnswer(query, Seq(Row(47100.0)))
@@ -465,4 +465,22 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
     checkAnswer(df2, Seq(Row(53000.00)))
   }
+
+  test("scan with aggregate push-down: aggregate with partially pushed down filters " +
+    "will NOT push down") {
+    val df = spark.table("h2.test.employee")
+    val name = udf { (x: String) => x.matches("cat|dav|amy") }
+    val sub = udf { (x: String) => x.substring(0, 3) }
+    val query = df.select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
+      .filter("SALARY > 100")
+      .filter(name($"shortName"))
+      .agg(sum($"SALARY").as("sum_salary"))
+    query.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: []"
+        checkKeywordsExistsInExplain(query, expected_plan_fragment)
+    }
+    checkAnswer(query, Seq(Row(29000.0)))
+  }
 }
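
Note on migration: for file sources outside this patch that previously mixed in SupportsPushDownFilters, the new contract is that FileScanBuilder.pushFilters splits the catalyst predicates into partition and data filters, translates the data filters to org.apache.spark.sql.sources.Filter, and hands them to pushDataFilters; the builder then forwards pushedDataFilters, partitionFilters and dataFilters into its FileScan. A minimal sketch of that pattern, assuming a hypothetical MyFormatScanBuilder and MyFormatScan that are not part of this change, and assuming the internal classes below are on the compile classpath:

// Hypothetical example (not part of this patch): a third-party file source
// migrating off SupportsPushDownFilters onto the new FileScanBuilder hooks.
package com.example.myformat

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.StructFilters
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class MyFormatScanBuilder(
    sparkSession: SparkSession,
    fileIndex: PartitioningAwareFileIndex,
    schema: StructType,
    dataSchema: StructType,
    options: CaseInsensitiveStringMap)
  extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {

  // Only decide which translated data filters this format can evaluate itself;
  // the partition/data split and the Filter translation already happened in
  // FileScanBuilder.pushFilters. This mirrors the CSV/JSON/Avro builders above.
  override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] =
    StructFilters.pushedFilters(dataFilters, dataSchema)

  override def build(): Scan = {
    // A real implementation would construct its FileScan here, forwarding the
    // state maintained by the base class: readDataSchema(), readPartitionSchema(),
    // pushedDataFilters, partitionFilters and dataFilters.
    ??? // MyFormatScan is hypothetical, so its construction is elided
  }
}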
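
The splitting itself lives in the new DataSourceUtils.getPartitionFiltersAndDataFilters helper, which is now shared by PruneFileSourcePartitions and FileScanBuilder: predicates referencing only partition columns become partition filters, everything else stays a data filter, and partition-only conjuncts of mixed predicates are additionally extracted for pruning. A small illustration against internal catalyst APIs; the columns p (partition) and c (data) and the demo object are made up for this sketch:

// Hypothetical demo (not part of this patch) of
// DataSourceUtils.getPartitionFiltersAndDataFilters.
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, GreaterThan, Literal}
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object PartitionDataFilterSplitDemo {
  def main(args: Array[String]): Unit = {
    val p = AttributeReference("p", IntegerType)() // partition column
    val c = AttributeReference("c", IntegerType)() // data column
    val partitionSchema = StructType(Seq(StructField("p", IntegerType)))

    // p = 1 references only the partition column, c > 2 only a data column.
    val (partitionFilters, dataFilters) =
      DataSourceUtils.getPartitionFiltersAndDataFilters(
        partitionSchema, Seq(EqualTo(p, Literal(1)), GreaterThan(c, Literal(2))))
    println(partitionFilters) // expected: (p = 1)
    println(dataFilters)      // expected: (c > 2)

    // A mixed conjunction stays a data filter (it must be re-evaluated after
    // the scan), but its partition-only conjunct is also extracted via
    // extractPredicatesWithinOutputSet so it can still prune partitions.
    val mixed = And(EqualTo(p, Literal(1)), GreaterThan(c, Literal(2)))
    val (pf, df) =
      DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, Seq(mixed))
    println(pf) // expected to contain: p = 1
    println(df) // expected: ((p = 1) AND (c > 2))
  }
}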