diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index f8c7e2c826a36..7681dc8dfb37a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -106,13 +106,16 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
     case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
       val scanBuilder = relation.newScanBuilder()
 
+      val (withSubquery, withoutSubquery) = filters.partition(SubqueryExpression.hasSubquery)
       val normalizedFilters = DataSourceStrategy.normalizeFilters(
-        filters.filterNot(SubqueryExpression.hasSubquery), relation.output)
+        withoutSubquery, relation.output)
 
       // `pushedFilters` will be pushed down and evaluated in the underlying data sources.
       // `postScanFilters` need to be evaluated after the scan.
       // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
-      val (pushedFilters, postScanFilters) = pushFilters(scanBuilder, normalizedFilters)
+      val (pushedFilters, postScanFiltersWithoutSubquery) =
+        pushFilters(scanBuilder, normalizedFilters)
+      val postScanFilters = postScanFiltersWithoutSubquery ++ withSubquery
       val (scan, output) = pruneColumns(scanBuilder, relation, project ++ postScanFilters)
       logInfo(
         s"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index 587cfa9bd6647..4e071c5af6a62 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -392,6 +392,19 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-27411: DataSourceV2Strategy should not eliminate subquery") {
+    withTempView("t1") {
+      val t2 = spark.read.format(classOf[SimpleDataSourceV2].getName).load()
+      Seq(2, 3).toDF("a").createTempView("t1")
+      val df = t2.where("i < (select max(a) from t1)").select('i)
+      val subqueries = df.queryExecution.executedPlan.collect {
+        case p => p.subqueries
+      }.flatten
+      assert(subqueries.length == 1)
+      checkAnswer(df, (0 until 3).map(i => Row(i)))
+    }
+  }
 }
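For reference, a minimal, self-contained sketch of the filter-splitting idea in the strategy change above. The `Filter`, `Simple`, and `WithSubquery` types, the `hasSubquery` helper, and the `FilterSplitSketch` object are toy stand-ins invented for illustration, not Spark's Catalyst classes; only the partition-and-recombine shape mirrors the diff.

```scala
// Sketch only: subquery filters are kept out of push-down and re-attached
// as post-scan filters, so they are never silently dropped.
object FilterSplitSketch {
  sealed trait Filter
  case class Simple(sql: String) extends Filter
  case class WithSubquery(sql: String) extends Filter

  // Stand-in for SubqueryExpression.hasSubquery.
  def hasSubquery(f: Filter): Boolean = f.isInstanceOf[WithSubquery]

  def main(args: Array[String]): Unit = {
    val filters: Seq[Filter] =
      Seq(Simple("i > 0"), WithSubquery("i < (select max(a) from t1)"))

    // Subquery filters are not offered to the source for push-down.
    val (withSubquery, withoutSubquery) = filters.partition(hasSubquery)

    // Pretend the source accepts every pushable filter but Spark still
    // re-evaluates it after the scan (the diff's comment notes pushed and
    // post-scan filters may overlap, e.g. parquet row group filters).
    val pushedFilters = withoutSubquery
    val postScanFilters = withoutSubquery ++ withSubquery

    println(s"pushed:    $pushedFilters")
    println(s"post-scan: $postScanFilters")
  }
}
```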