From 29b068c66caa9fc111701e72fa727072227bf2be Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Wed, 7 Jul 2021 04:25:54 +0000 Subject: [PATCH] [SPARK-36028][SQL] Allow Project to host outer references in scalar subqueries This PR allows the `Project` node to host outer references in scalar subqueries when `decorrelateInnerQuery` is enabled. It is already supported by the new decorrelation framework and the `RewriteCorrelatedScalarSubquery` rule. Note currently by default all correlated subqueries will be decorrelated, which is not necessarily the most optimal approach. Consider `SELECT (SELECT c1) FROM t`. This should be optimized as `SELECT c1 FROM t` instead of rewriting it as a left outer join. This will be done in a separate PR to optimize correlated scalar/lateral subqueries with OneRowRelation. To allow more types of correlated scalar subqueries. Yes. This PR allows outer query column references in the SELECT clause of a correlated scalar subquery. For example: ```sql SELECT (SELECT c1) FROM t; ``` Before this change: ``` org.apache.spark.sql.AnalysisException: Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses ``` After this change: ``` +------------------+ |scalarsubquery(c1)| +------------------+ |0 | |1 | +------------------+ ``` Added unit tests and SQL tests. Closes #33235 from allisonwang-db/spark-36028-outer-in-project. 
Authored-by: allisonwang-db Signed-off-by: Wenchen Fan (cherry picked from commit ca348e50a4edbd857ec86e4e9693fa4bcbab54b7) Signed-off-by: allisonwang-db --- .../sql/catalyst/analysis/CheckAnalysis.scala | 23 +++-- .../analysis/AnalysisErrorSuite.scala | 7 -- .../analysis/ResolveSubquerySuite.scala | 26 ++++- .../scalar-subquery-select.sql | 9 +- .../scalar-subquery-select.sql.out | 97 ++++++++++++++++++- 5 files changed, 144 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index e439085633a59..c1578483ca921 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -725,9 +725,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { s"Filter/Aggregate/Project and a few commands: $plan") } } + // Validate to make sure the correlations appearing in the query are valid and + // allowed by spark. + checkCorrelationsInSubquery(expr.plan, isScalarOrLateral = true) case _: LateralSubquery => assert(plan.isInstanceOf[LateralJoin]) + // Validate to make sure the correlations appearing in the query are valid and + // allowed by spark. + checkCorrelationsInSubquery(expr.plan, isScalarOrLateral = true) case inSubqueryOrExistsSubquery => plan match { @@ -736,11 +742,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in" + s" Filter/Join and a few commands: $plan") } + // Validate to make sure the correlations appearing in the query are valid and + // allowed by spark. + checkCorrelationsInSubquery(expr.plan) } - - // Validate to make sure the correlations appearing in the query are valid and - // allowed by spark. 
- checkCorrelationsInSubquery(expr.plan, isLateral = plan.isInstanceOf[LateralJoin]) } /** @@ -779,7 +784,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { * Validates to make sure the outer references appearing inside the subquery * are allowed. */ - private def checkCorrelationsInSubquery(sub: LogicalPlan, isLateral: Boolean = false): Unit = { + private def checkCorrelationsInSubquery( + sub: LogicalPlan, + isScalarOrLateral: Boolean = false): Unit = { // Validate that correlated aggregate expression do not contain a mixture // of outer and local references. def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { @@ -800,11 +807,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { } // Check whether the logical plan node can host outer references. - // A `Project` can host outer references if it is inside a lateral subquery. - // Otherwise, only Filter can only outer references. + // A `Project` can host outer references if it is inside a scalar or a lateral subquery and + // DecorrelateInnerQuery is enabled. Otherwise, only Filter can only outer references. 
def canHostOuter(plan: LogicalPlan): Boolean = plan match { case _: Filter => true - case _: Project => isLateral + case _: Project => isScalarOrLateral && SQLConf.get.decorrelateInnerQueryEnabled case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 3ac9874f97206..6cda05360aea3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -824,13 +824,6 @@ class AnalysisErrorSuite extends AnalysisTest { Project(ScalarSubquery(t0.select(star("t1"))).as("sub") :: Nil, t1), "Scalar subquery must return only one column, but got 2" :: Nil) - // array(t1.*) in the subquery should be resolved into array(outer(t1.a), outer(t1.b)) - val array = CreateArray(Seq(star("t1"))) - assertAnalysisError( - Project(ScalarSubquery(t0.select(array)).as("sub") :: Nil, t1), - "Expressions referencing the outer query are not supported outside" + - " of WHERE/HAVING clauses" :: Nil) - // t2.* cannot be resolved and the error should be the initial analysis exception. 
assertAnalysisError( Project(ScalarSubquery(t0.select(star("t2"))).as("sub") :: Nil, t1), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 62ce863f62d8f..212f2b856d4b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference} +import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference, ScalarSubquery} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.plans.logical._ @@ -240,4 +240,28 @@ class ResolveSubquerySuite extends AnalysisTest { Inner, None) ) } + + test("SPARK-36028: resolve scalar subqueries with outer references in Project") { + // SELECT (SELECT a) FROM t1 + checkAnalysis( + Project(ScalarSubquery(t0.select('a)).as("sub") :: Nil, t1), + Project(ScalarSubquery(Project(OuterReference(a) :: Nil, t0), Seq(a)).as("sub") :: Nil, t1) + ) + // SELECT (SELECT a + b + c AS r FROM t2) FROM t1 + checkAnalysis( + Project(ScalarSubquery( + t2.select(('a + 'b + 'c).as("r"))).as("sub") :: Nil, t1), + Project(ScalarSubquery( + Project((OuterReference(a) + b + c).as("r") :: Nil, t2), Seq(a)).as("sub") :: Nil, t1) + ) + // SELECT (SELECT array(t1.*) AS arr) FROM t1 + checkAnalysis( + Project(ScalarSubquery(t0.select( 
+ CreateArray(Seq(star("t1"))).as("arr")) + ).as("sub") :: Nil, t1.as("t1")), + Project(ScalarSubquery(Project( + CreateArray(Seq(OuterReference(a), OuterReference(b))).as("arr") :: Nil, t0 + ), Seq(a, b)).as("sub") :: Nil, t1) + ) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql index 936da959efabf..a76a010722090 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql @@ -137,4 +137,11 @@ SELECT t1a, (SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2, (SELECT sort_array(collect_set(t2d)) FROM t2 WHERE t2a = t1a) collect_set_t2, (SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2 -FROM t1; \ No newline at end of file +FROM t1; + +-- SPARK-36028: Allow Project to host outer references in scalar subqueries +SELECT t1c, (SELECT t1c) FROM t1; +SELECT t1c, (SELECT t1c WHERE t1c = 8) FROM t1; +SELECT t1c, t1d, (SELECT c + d FROM (SELECT t1c AS c, t1d AS d)) FROM t1; +SELECT t1c, (SELECT SUM(c) FROM (SELECT t1c AS c)) FROM t1; +SELECT t1a, (SELECT SUM(t2b) FROM t2 JOIN (SELECT t1a AS a) ON t2a = a) FROM t1; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out index 68aad89e407a2..8fac940f8efd0 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 17 -- !query @@ 
-222,3 +222,98 @@ val1d 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB9000000 val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000 val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000 val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000 + + +-- !query +SELECT t1c, (SELECT t1c) FROM t1 +-- !query schema +struct +-- !query output +12 12 +12 12 +16 16 +16 16 +16 16 +16 16 +8 8 +8 8 +NULL NULL +NULL NULL +NULL NULL +NULL NULL + + +-- !query +SELECT t1c, (SELECT t1c WHERE t1c = 8) FROM t1 +-- !query schema +struct +-- !query output +12 NULL +12 NULL +16 NULL +16 NULL +16 NULL +16 NULL +8 8 +8 8 +NULL NULL +NULL NULL +NULL NULL +NULL NULL + + +-- !query +SELECT t1c, t1d, (SELECT c + d FROM (SELECT t1c AS c, t1d AS d)) FROM t1 +-- !query schema +struct +-- !query output +12 10 22 +12 21 33 +16 19 35 +16 19 35 +16 19 35 +16 22 38 +8 10 18 +8 10 18 +NULL 12 NULL +NULL 19 NULL +NULL 19 NULL +NULL 25 NULL + + +-- !query +SELECT t1c, (SELECT SUM(c) FROM (SELECT t1c AS c)) FROM t1 +-- !query schema +struct +-- !query output +12 12 +12 12 +16 16 +16 16 +16 16 +16 16 +8 8 +8 8 +NULL NULL +NULL NULL +NULL NULL +NULL NULL + + +-- !query +SELECT t1a, (SELECT SUM(t2b) FROM t2 JOIN (SELECT t1a AS a) ON t2a = a) FROM t1 +-- !query schema +struct +-- !query output +val1a NULL +val1a NULL +val1a NULL +val1a NULL +val1b 36 +val1c 24 +val1d NULL +val1d NULL +val1d NULL +val1e 8 +val1e 8 +val1e 8