
Commit af441dd

[SPARK-13306][SQL] Addendum to uncorrelated scalar subquery
## What changes were proposed in this pull request?

This pull request fixes some minor issues (documentation, test flakiness, test organization) with #11190, which was merged earlier tonight.

## How was this patch tested?

Unit tests.

Author: Reynold Xin <[email protected]>

Closes #11285 from rxin/subquery.
1 parent 0947f09 commit af441dd

6 files changed: +61 -70 lines changed
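For context, an uncorrelated scalar subquery is a parenthesized query that references no columns of the enclosing query and evaluates to a single value (at most one row, exactly one column). A minimal sketch of one in action against the SQLContext API of this era; the application scaffolding below is illustrative, only the query shape comes from the feature this patch touches:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object ScalarSubqueryExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("example").setMaster("local"))
    val sqlContext = new SQLContext(sc)
    // The inner (select 1) is independent of the outer query and yields one value.
    sqlContext.sql("select (select 1) + 1").show()  // one row containing 2
    sc.stop()
  }
}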

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 2 deletions

@@ -123,13 +123,12 @@ class Analyzer(
          }
          substituted.getOrElse(u)
        case other =>
-         // This can't be done in ResolveSubquery because that does not know the CTE.
+         // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
          other transformExpressions {
            case e: SubqueryExpression =>
              e.withNewPlan(substituteCTE(e.query, cteRelations))
          }
      }
-
    }
  }
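The reworded comment is about CTE names being visible inside subquery expressions: substituteCTE recurses into each SubqueryExpression's query so that a CTE defined in the outer query can be referenced from a scalar subquery. The new "uncorrelated scalar subquery in CTE" test below exercises exactly this case; quoted from the suite, with the scalar subquery (select max(b) from t2) referring to the CTE t2:

sql("with t2 as (select 1 as b, 2 as c) " +
  "select a from (select 1 as a union all select 2 as a) t " +
  "where a = (select max(b) from t2) ")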

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 0 additions & 11 deletions

@@ -255,14 +255,3 @@ case class Literal protected (value: Any, dataType: DataType)
     case _ => value.toString
   }
 }
-
-// TODO: Specialize
-case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
-  extends LeafExpression with CodegenFallback {
-
-  def update(expression: Expression, input: InternalRow): Unit = {
-    value = expression.eval(input)
-  }
-
-  override def eval(input: InternalRow): Any = value
-}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 2 additions & 3 deletions

@@ -45,9 +45,8 @@ abstract class SubqueryExpression extends LeafExpression {
 }

 /**
- * A subquery that will return only one row and one column.
- *
- * This will be converted into [[execution.ScalarSubquery]] during physical planning.
+ * A subquery that will return only one row and one column. This will be converted into a physical
+ * scalar subquery during planning.
  *
  * Note: `exprId` is used to have unique name in explain string output.
  */

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 27 additions & 22 deletions

@@ -46,7 +46,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    * populated by the query planning infrastructure.
    */
   @transient
-  protected[spark] final val sqlContext = SQLContext.getActive().getOrElse(null)
+  protected[spark] final val sqlContext = SQLContext.getActive().orNull

   protected def sparkContext = sqlContext.sparkContext

@@ -120,44 +120,49 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     }
   }

-  // All the subqueries and their Future of results.
-  @transient private val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]()
+  /**
+   * List of (uncorrelated scalar subquery, future holding the subquery result) for this plan node.
+   * This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
+   */
+  @transient
+  private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]

   /**
-   * Collects all the subqueries and create a Future to take the first two rows of them.
+   * Finds scalar subquery expressions in this plan node and starts evaluating them.
+   * The list of subqueries are added to [[subqueryResults]].
    */
   protected def prepareSubqueries(): Unit = {
     val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
     allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
       val futureResult = Future {
-        // We only need the first row, try to take two rows so we can throw an exception if there
-        // are more than one rows returned.
+        // Each subquery should return only one row (and one column). We take two here and throws
+        // an exception later if the number of rows is greater than one.
         e.executedPlan.executeTake(2)
       }(SparkPlan.subqueryExecutionContext)
-      queryResults += e -> futureResult
+      subqueryResults += e -> futureResult
     }
   }

   /**
-   * Waits for all the subqueries to finish and updates the results.
+   * Blocks the thread until all subqueries finish evaluation and update the results.
    */
   protected def waitForSubqueries(): Unit = {
     // fill in the result of subqueries
-    queryResults.foreach {
-      case (e, futureResult) =>
-        val rows = Await.result(futureResult, Duration.Inf)
-        if (rows.length > 1) {
-          sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
-        }
-        if (rows.length == 1) {
-          assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column")
-          e.updateResult(rows(0).get(0, e.dataType))
-        } else {
-          // There is no rows returned, the result should be null.
-          e.updateResult(null)
-        }
+    subqueryResults.foreach { case (e, futureResult) =>
+      val rows = Await.result(futureResult, Duration.Inf)
+      if (rows.length > 1) {
+        sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
+      }
+      if (rows.length == 1) {
+        assert(rows(0).numFields == 1,
+          s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
+        e.updateResult(rows(0).get(0, e.dataType))
+      } else {
+        // If there is no rows returned, the result should be null.
+        e.updateResult(null)
+      }
     }
-    queryResults.clear()
+    subqueryResults.clear()
   }

   /**
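The mechanism above is a two-phase handshake: prepareSubqueries() fires each uncorrelated scalar subquery on SparkPlan.subqueryExecutionContext without blocking, and waitForSubqueries() later blocks for the results while enforcing the at-most-one-row, exactly-one-column contract. A minimal self-contained sketch of the same pattern; ToySubquery and SubqueryPattern are hypothetical stand-ins, not Spark APIs:

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

// Toy stand-in for ScalarSubquery: run() evaluates the subquery plan.
case class ToySubquery(run: () => Array[Int]) {
  var result: Option[Int] = None  // filled in later, like updateResult()
}

object SubqueryPattern {
  private implicit val ec: ExecutionContext = ExecutionContext.global
  private val pending = new ArrayBuffer[(ToySubquery, Future[Array[Int]])]

  // Analogue of prepareSubqueries(): start evaluation eagerly, do not block.
  def prepare(subqueries: Seq[ToySubquery]): Unit = {
    subqueries.foreach { s =>
      // Take two rows so a cardinality violation can be detected later.
      pending += s -> Future(s.run().take(2))
    }
  }

  // Analogue of waitForSubqueries(): block until every future completes.
  def awaitAll(): Unit = {
    pending.foreach { case (s, f) =>
      val rows = Await.result(f, Duration.Inf)
      if (rows.length > 1) {
        sys.error("more than one row returned by a subquery used as an expression")
      }
      s.result = rows.headOption  // zero rows yields None (null in SQL terms)
    }
    pending.clear()
  }
}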

sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ case class ScalarSubquery(
 /**
  * Convert the subquery from logical plan into executed plan.
  */
-private[sql] case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressions {
       case subquery: expressions.ScalarSubquery =>
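PlanSubqueries, whose private[sql] modifier is dropped here, is a rewrite rule: it walks every expression in a physical plan and replaces each logical scalar subquery with a planned physical one. A toy sketch of that rule shape; Expr, LogicalSub, PhysicalSub, and Plan are hypothetical types, not Catalyst classes:

// Toy expression hierarchy standing in for Catalyst expressions.
sealed trait Expr
case class LogicalSub(query: String) extends Expr
case class PhysicalSub(plannedQuery: String) extends Expr

// Toy plan node exposing the transformAllExpressions hook the real rule uses.
case class Plan(exprs: Seq[Expr]) {
  def transformAllExpressions(f: PartialFunction[Expr, Expr]): Plan =
    Plan(exprs.map(e => if (f.isDefinedAt(e)) f(e) else e))
}

// Shape of PlanSubqueries: plan each logical subquery and wrap it as physical.
case class PlanSubqueriesSketch(planQuery: String => String) {
  def apply(plan: Plan): Plan = plan.transformAllExpressions {
    case LogicalSub(q) => PhysicalSub(planQuery(q))
  }
}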

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 30 additions & 31 deletions

@@ -20,65 +20,64 @@ package org.apache.spark.sql
 import org.apache.spark.sql.test.SharedSQLContext

 class SubquerySuite extends QueryTest with SharedSQLContext {
+  import testImplicits._

   test("simple uncorrelated scalar subquery") {
     assertResult(Array(Row(1))) {
       sql("select (select 1 as b) as b").collect()
     }

-    assertResult(Array(Row(1))) {
-      sql("with t2 as (select 1 as b, 2 as c) " +
-        "select a from (select 1 as a union all select 2 as a) t " +
-        "where a = (select max(b) from t2) ").collect()
-    }
-
     assertResult(Array(Row(3))) {
       sql("select (select (select 1) + 1) + 1").collect()
     }

-    // more than one columns
-    val error = intercept[AnalysisException] {
-      sql("select (select 1, 2) as b").collect()
-    }
-    assert(error.message contains "Scalar subquery must return only one column, but got 2")
-
-    // more than one rows
-    val error2 = intercept[RuntimeException] {
-      sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
-    }
-    assert(error2.getMessage contains
-      "more than one row returned by a subquery used as an expression")
-
     // string type
     assertResult(Array(Row("s"))) {
       sql("select (select 's' as s) as b").collect()
     }
+  }

-    // zero rows
+  test("uncorrelated scalar subquery in CTE") {
+    assertResult(Array(Row(1))) {
+      sql("with t2 as (select 1 as b, 2 as c) " +
+        "select a from (select 1 as a union all select 2 as a) t " +
+        "where a = (select max(b) from t2) ").collect()
+    }
+  }
+
+  test("uncorrelated scalar subquery should return null if there is 0 rows") {
     assertResult(Array(Row(null))) {
       sql("select (select 's' as s limit 0) as b").collect()
     }
   }

-  test("uncorrelated scalar subquery on testData") {
-    // initialize test Data
-    testData
+  test("runtime error when the number of rows is greater than 1") {
+    val error2 = intercept[RuntimeException] {
+      sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
+    }
+    assert(error2.getMessage.contains(
+      "more than one row returned by a subquery used as an expression"))
+  }
+
+  test("uncorrelated scalar subquery on a DataFrame generated query") {
+    val df = Seq((1, "one"), (2, "two"), (3, "three")).toDF("key", "value")
+    df.registerTempTable("subqueryData")

-    assertResult(Array(Row(5))) {
-      sql("select (select key from testData where key > 3 limit 1) + 1").collect()
+    assertResult(Array(Row(4))) {
+      sql("select (select key from subqueryData where key > 2 order by key limit 1) + 1").collect()
     }

-    assertResult(Array(Row(-100))) {
-      sql("select -(select max(key) from testData)").collect()
+    assertResult(Array(Row(-3))) {
+      sql("select -(select max(key) from subqueryData)").collect()
     }

     assertResult(Array(Row(null))) {
-      sql("select (select value from testData limit 0)").collect()
+      sql("select (select value from subqueryData limit 0)").collect()
     }

-    assertResult(Array(Row("99"))) {
-      sql("select (select min(value) from testData" +
-        " where key = (select max(key) from testData) - 1)").collect()
+    assertResult(Array(Row("two"))) {
+      sql("select (select min(value) from subqueryData" +
+        " where key = (select max(key) from subqueryData) - 1)").collect()
     }
   }
 }
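A likely source of the "test flakiness" named in the commit message is visible in the rewritten queries: limit 1 without an order by lets the engine return any qualifying row, so an assertion on the returned value could pass or fail across runs; the move off the shared testData fixture onto a locally registered subqueryData table also isolates the test. The contrast, echoing the suite above (the first form is shown only to illustrate the hazard):

// Non-deterministic: without "order by", any row with key > 2 may be chosen.
sql("select (select key from subqueryData where key > 2 limit 1) + 1")
// Deterministic: "order by key" pins the chosen row to key = 3, so the result is 4.
sql("select (select key from subqueryData where key > 2 order by key limit 1) + 1")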
