Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -762,9 +762,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
s"Filter/Aggregate/Project and a few commands: $plan")
}
}
// Validate to make sure the correlations appearing in the query are valid and
// allowed by spark.
checkCorrelationsInSubquery(expr.plan, isScalarOrLateral = true)

case _: LateralSubquery =>
assert(plan.isInstanceOf[LateralJoin])
// Validate to make sure the correlations appearing in the query are valid and
// allowed by spark.
checkCorrelationsInSubquery(expr.plan, isScalarOrLateral = true)

case inSubqueryOrExistsSubquery =>
plan match {
Expand All @@ -773,11 +779,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in" +
s" Filter/Join and a few commands: $plan")
}
// Validate to make sure the correlations appearing in the query are valid and
// allowed by spark.
checkCorrelationsInSubquery(expr.plan)
}

// Validate to make sure the correlations appearing in the query are valid and
// allowed by spark.
checkCorrelationsInSubquery(expr.plan, isLateral = plan.isInstanceOf[LateralJoin])
}

/**
Expand Down Expand Up @@ -816,7 +821,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
* Validates to make sure the outer references appearing inside the subquery
* are allowed.
*/
private def checkCorrelationsInSubquery(sub: LogicalPlan, isLateral: Boolean = false): Unit = {
private def checkCorrelationsInSubquery(
sub: LogicalPlan,
isScalarOrLateral: Boolean = false): Unit = {
// Validate that correlated aggregate expression do not contain a mixture
// of outer and local references.
def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = {
Expand All @@ -837,11 +844,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
}

// Check whether the logical plan node can host outer references.
// A `Project` can host outer references if it is inside a lateral subquery.
// Otherwise, only Filter can only outer references.
// A `Project` can host outer references if it is inside a scalar or a lateral subquery and
// DecorrelateInnerQuery is enabled. Otherwise, only a `Filter` can host outer references.
def canHostOuter(plan: LogicalPlan): Boolean = plan match {
case _: Filter => true
case _: Project => isLateral
case _: Project => isScalarOrLateral && SQLConf.get.decorrelateInnerQueryEnabled
case _ => false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -824,13 +824,6 @@ class AnalysisErrorSuite extends AnalysisTest {
Project(ScalarSubquery(t0.select(star("t1"))).as("sub") :: Nil, t1),
"Scalar subquery must return only one column, but got 2" :: Nil)

// array(t1.*) in the subquery should be resolved into array(outer(t1.a), outer(t1.b))
val array = CreateArray(Seq(star("t1")))
assertAnalysisError(
Project(ScalarSubquery(t0.select(array)).as("sub") :: Nil, t1),
"Expressions referencing the outer query are not supported outside" +
" of WHERE/HAVING clauses" :: Nil)

// t2.* cannot be resolved and the error should be the initial analysis exception.
assertAnalysisError(
Project(ScalarSubquery(t0.select(star("t2"))).as("sub") :: Nil, t1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference}
import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -240,4 +240,28 @@ class ResolveSubquerySuite extends AnalysisTest {
Inner, None)
)
}

// Verifies that scalar subqueries whose select list refers to columns of the outer query are
// accepted by the analyzer. Each `checkAnalysis` call below is given two plans: the first is
// built with unresolved attributes, the second spells out the resolved form in which the
// outer columns are wrapped in `OuterReference` and also passed as the second argument of
// `ScalarSubquery`.
// NOTE(review): `t0`, `t1`, `t2`, `a`, `b`, `c` and `checkAnalysis` are suite fixtures defined
// outside this hunk; presumably `checkAnalysis(input, expected)` compares the analyzed input
// against the expected plan -- confirm against the suite.
test("SPARK-36028: resolve scalar subqueries with outer references in Project") {
// Case 1: the subquery's only select-list item is an outer column.
// SELECT (SELECT a) FROM t1
checkAnalysis(
Project(ScalarSubquery(t0.select('a)).as("sub") :: Nil, t1),
Project(ScalarSubquery(Project(OuterReference(a) :: Nil, t0), Seq(a)).as("sub") :: Nil, t1)
)
// Case 2: an expression mixing an outer column with the subquery's own columns. Only `a` is
// expected to become an outer reference; `b` and `c` stay as plain attributes over `t2`.
// SELECT (SELECT a + b + c AS r FROM t2) FROM t1
checkAnalysis(
Project(ScalarSubquery(
t2.select(('a + 'b + 'c).as("r"))).as("sub") :: Nil, t1),
Project(ScalarSubquery(
Project((OuterReference(a) + b + c).as("r") :: Nil, t2), Seq(a)).as("sub") :: Nil, t1)
)
// Case 3: a qualified star over the outer relation inside the subquery. `t1.*` is expected to
// expand into one outer reference per outer column (`a` and `b`).
// SELECT (SELECT array(t1.*) AS arr) FROM t1
checkAnalysis(
Project(ScalarSubquery(t0.select(
CreateArray(Seq(star("t1"))).as("arr"))
).as("sub") :: Nil, t1.as("t1")),
Project(ScalarSubquery(Project(
CreateArray(Seq(OuterReference(a), OuterReference(b))).as("arr") :: Nil, t0
), Seq(a, b)).as("sub") :: Nil, t1)
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,11 @@ SELECT t1a,
(SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2,
(SELECT sort_array(collect_set(t2d)) FROM t2 WHERE t2a = t1a) collect_set_t2,
(SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2
FROM t1;
FROM t1;

-- SPARK-36028: Allow Project to host outer references in scalar subqueries
-- Outer column used directly as the select list of a subquery with no FROM clause.
SELECT t1c, (SELECT t1c) FROM t1;
-- Outer column used in both the select list and the WHERE clause of the subquery.
SELECT t1c, (SELECT t1c WHERE t1c = 8) FROM t1;
-- Outer columns projected under aliases in a derived table and combined by the enclosing subquery.
SELECT t1c, t1d, (SELECT c + d FROM (SELECT t1c AS c, t1d AS d)) FROM t1;
-- Aggregate computed over a derived table that projects an outer column.
SELECT t1c, (SELECT SUM(c) FROM (SELECT t1c AS c)) FROM t1;
-- Outer column projected in a derived table that is then used as a join input and join key.
SELECT t1a, (SELECT SUM(t2b) FROM t2 JOIN (SELECT t1a AS a) ON t2a = a) FROM t1;
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 12
-- Number of queries: 17


-- !query
Expand Down Expand Up @@ -222,3 +222,98 @@ val1d 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB9000000
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000


-- !query
SELECT t1c, (SELECT t1c) FROM t1
-- !query schema
struct<t1c:int,scalarsubquery(t1c):int>
-- !query output
12 12
12 12
16 16
16 16
16 16
16 16
8 8
8 8
NULL NULL
NULL NULL
NULL NULL
NULL NULL


-- !query
SELECT t1c, (SELECT t1c WHERE t1c = 8) FROM t1
-- !query schema
struct<t1c:int,scalarsubquery(t1c, t1c):int>
-- !query output
12 NULL
12 NULL
16 NULL
16 NULL
16 NULL
16 NULL
8 8
8 8
NULL NULL
NULL NULL
NULL NULL
NULL NULL


-- !query
SELECT t1c, t1d, (SELECT c + d FROM (SELECT t1c AS c, t1d AS d)) FROM t1
-- !query schema
struct<t1c:int,t1d:bigint,scalarsubquery(t1c, t1d):bigint>
-- !query output
12 10 22
12 21 33
16 19 35
16 19 35
16 19 35
16 22 38
8 10 18
8 10 18
NULL 12 NULL
NULL 19 NULL
NULL 19 NULL
NULL 25 NULL


-- !query
SELECT t1c, (SELECT SUM(c) FROM (SELECT t1c AS c)) FROM t1
-- !query schema
struct<t1c:int,scalarsubquery(t1c):bigint>
-- !query output
12 12
12 12
16 16
16 16
16 16
16 16
8 8
8 8
NULL NULL
NULL NULL
NULL NULL
NULL NULL


-- !query
SELECT t1a, (SELECT SUM(t2b) FROM t2 JOIN (SELECT t1a AS a) ON t2a = a) FROM t1
-- !query schema
struct<t1a:string,scalarsubquery(t1a):bigint>
-- !query output
val1a NULL
val1a NULL
val1a NULL
val1a NULL
val1b 36
val1c 24
val1d NULL
val1d NULL
val1d NULL
val1e 8
val1e 8
val1e 8