
Commit c9bd987

[SPARK-13306][SQL] Addendum to uncorrelated scalar subquery
Parent commit: 7925071


5 files changed: +47 -51 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 2 deletions
@@ -122,13 +122,12 @@ class Analyzer(
           }
           substituted.getOrElse(u)
         case other =>
-          // This can't be done in ResolveSubquery because that does not know the CTE.
+          // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
           other transformExpressions {
             case e: SubqueryExpression =>
               e.withNewPlan(substituteCTE(e.query, cteRelations))
           }
       }
-
     }
   }
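
The comment rewritten here documents why CTE substitution recurses into subquery expressions at this point: ResolveSubquery runs later in analysis and no longer sees the WITH-clause relations. A minimal sketch of the query shape this path handles, assuming a Spark SQLContext of this era (the object name and setup are illustrative, not part of the commit):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object CteSubqueryExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("cte-subquery").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // A CTE referenced from inside an uncorrelated scalar subquery. The CTE
    // substitution step above rewrites `t` inside the SubqueryExpression's plan;
    // without it, the subquery would fail to resolve the relation `t`.
    sqlContext.sql("with t as (select 1 as x) select (select x from t) as b").show()

    sc.stop()
  }
}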

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 0 additions & 11 deletions
@@ -272,14 +272,3 @@ case class Literal protected (value: Any, dataType: DataType)
     case _ => value.toString
   }
 }
-
-// TODO: Specialize
-case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
-  extends LeafExpression with CodegenFallback {
-
-  def update(expression: Expression, input: InternalRow): Unit = {
-    value = expression.eval(input)
-  }
-
-  override def eval(input: InternalRow): Any = value
-}
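
MutableLiteral became dead code once the main SPARK-13306 patch gave the physical ScalarSubquery expression its own mutable result slot (the `updateResult` calls visible in the SparkPlan diff below). A toy sketch of that fill-in-later pattern, with hypothetical names rather than the real Catalyst types:

// Toy version of an expression that caches an externally supplied value:
// updateResult is called once the subquery future completes, and eval then
// just returns the cached value.
class CachedScalarResult {
  private var result: Any = null
  private var updated: Boolean = false

  def updateResult(v: Any): Unit = { result = v; updated = true }

  def eval(): Any = {
    require(updated, "subquery result not yet filled in")
    result
  }
}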

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 27 additions & 22 deletions
@@ -46,7 +46,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    * populated by the query planning infrastructure.
    */
   @transient
-  protected[spark] final val sqlContext = SQLContext.getActive().getOrElse(null)
+  protected[spark] final val sqlContext = SQLContext.getActive().orNull

   protected def sparkContext = sqlContext.sparkContext

@@ -120,44 +120,49 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     }
   }

-  // All the subqueries and their Future of results.
-  @transient private val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]()
+  /**
+   * List of (uncorrelated scalar subquery, future holding the subquery result) for this plan node.
+   * This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
+   */
+  @transient
+  private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]

   /**
-   * Collects all the subqueries and create a Future to take the first two rows of them.
+   * Finds scalar subquery expressions in this plan node and starts evaluating them.
+   * The list of subqueries are added to [[subqueryResults]].
    */
   protected def prepareSubqueries(): Unit = {
     val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
     allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
       val futureResult = Future {
-        // We only need the first row, try to take two rows so we can throw an exception if there
-        // are more than one rows returned.
+        // Each subquery should return only one row (and one column). We take two here and throws
+        // an exception later if the number of rows is greater than one.
         e.executedPlan.executeTake(2)
       }(SparkPlan.subqueryExecutionContext)
-      queryResults += e -> futureResult
+      subqueryResults += e -> futureResult
     }
   }

   /**
-   * Waits for all the subqueries to finish and updates the results.
+   * Blocks the thread until all subqueries finish evaluation and update the results.
    */
   protected def waitForSubqueries(): Unit = {
     // fill in the result of subqueries
-    queryResults.foreach {
-      case (e, futureResult) =>
-        val rows = Await.result(futureResult, Duration.Inf)
-        if (rows.length > 1) {
-          sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
-        }
-        if (rows.length == 1) {
-          assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column")
-          e.updateResult(rows(0).get(0, e.dataType))
-        } else {
-          // There is no rows returned, the result should be null.
-          e.updateResult(null)
-        }
+    subqueryResults.foreach { case (e, futureResult) =>
+      val rows = Await.result(futureResult, Duration.Inf)
+      if (rows.length > 1) {
+        sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
+      }
+      if (rows.length == 1) {
+        assert(rows(0).numFields == 1,
+          s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
+        e.updateResult(rows(0).get(0, e.dataType))
+      } else {
+        // If there is no rows returned, the result should be null.
+        e.updateResult(null)
+      }
     }
-    queryResults.clear()
+    subqueryResults.clear()
   }

   /**
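
The prepare/wait split above is the core of the execution change: `prepareSubqueries` launches each uncorrelated scalar subquery as a Future on a dedicated thread pool, fetching at most two rows so a multi-row result can be detected without materializing everything, and `waitForSubqueries` later blocks on each future and folds the single value (or null, for zero rows) back into the expression. A self-contained sketch of that pattern using plain scala.concurrent, with `runSubquery` standing in for `executedPlan.executeTake(2)` (all names here are illustrative, not Spark APIs):

import java.util.concurrent.Executors
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object SubqueryFuturesSketch {
  // Dedicated pool, analogous to SparkPlan.subqueryExecutionContext.
  private val subqueryExecutionContext =
    ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(4))

  private val subqueryResults = new ArrayBuffer[(String, Future[Seq[Int]])]

  // "prepare": start every subquery asynchronously, keeping at most two rows
  // so that a multi-row result can be detected without fetching all rows.
  def prepare(name: String, runSubquery: () => Seq[Int]): Unit = {
    subqueryResults += name -> Future { runSubquery().take(2) }(subqueryExecutionContext)
  }

  // "wait": block on each future, then validate the row count.
  def waitForSubqueries(): Unit = {
    subqueryResults.foreach { case (name, f) =>
      val rows = Await.result(f, Duration.Inf)
      if (rows.length > 1) {
        sys.error(s"more than one row returned by a subquery used as an expression: $name")
      }
      // One row yields its value; zero rows yield null, matching SQL semantics.
      val value: Any = if (rows.length == 1) rows.head else null
      println(s"$name -> $value")
    }
    subqueryResults.clear()
  }

  def main(args: Array[String]): Unit = {
    prepare("one-row", () => Seq(42))
    prepare("zero-rows", () => Seq.empty[Int])
    waitForSubqueries()
    subqueryExecutionContext.shutdown()
  }
}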

sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ case class ScalarSubquery(
 /**
  * Convert the subquery from logical plan into executed plan.
  */
-private[sql] case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressions {
       case subquery: expressions.ScalarSubquery =>
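
Dropping `private[sql]` makes the rule visible outside the package, presumably for reuse and testing. Functionally, PlanSubqueries walks a physical plan and replaces each logical ScalarSubquery expression with an executable one. A toy model of that transform, using stand-in types instead of Catalyst's TreeNode API:

// Illustrative stand-ins for Catalyst's expression and plan types.
sealed trait Expr
case class LogicalScalarSubquery(sql: String) extends Expr
case class ExecutedScalarSubquery(sql: String) extends Expr
case class Plan(exprs: Seq[Expr], children: Seq[Plan] = Nil)

object PlanSubqueriesSketch {
  // Analogous to plan.transformAllExpressions: rewrite every matching
  // expression in this node and, recursively, in all children.
  def apply(plan: Plan): Plan = {
    val rewritten = plan.exprs.map {
      case LogicalScalarSubquery(sql) => ExecutedScalarSubquery(sql) // "plan" it
      case other                      => other
    }
    Plan(rewritten, plan.children.map(apply))
  }

  def main(args: Array[String]): Unit = {
    val before = Plan(Seq(LogicalScalarSubquery("select 1")))
    println(apply(before)) // Plan(List(ExecutedScalarSubquery(select 1)),List())
  }
}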

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 18 additions & 15 deletions
@@ -36,36 +36,39 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
       sql("select (select (select 1) + 1) + 1").collect()
     }

-    // more than one columns
-    val error = intercept[AnalysisException] {
-      sql("select (select 1, 2) as b").collect()
-    }
-    assert(error.message contains "Scalar subquery must return only one column, but got 2")
-
-    // more than one rows
-    val error2 = intercept[RuntimeException] {
-      sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
-    }
-    assert(error2.getMessage contains
-      "more than one row returned by a subquery used as an expression")
-
     // string type
     assertResult(Array(Row("s"))) {
       sql("select (select 's' as s) as b").collect()
     }
+  }

-    // zero rows
+  test("uncorrelated scalar subquery should return null if there is 0 rows") {
     assertResult(Array(Row(null))) {
       sql("select (select 's' as s limit 0) as b").collect()
     }
   }

+  test("analysis error when the number of columns is not 1") {
+    val error = intercept[AnalysisException] {
+      sql("select (select 1, 2) as b").collect()
+    }
+    assert(error.message.contains("Scalar subquery must return only one column, but got 2"))
+  }
+
+  test("runtime error when the number of rows is greater than 1") {
+    val error2 = intercept[RuntimeException] {
+      sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
+    }
+    assert(error2.getMessage.contains(
+      "more than one row returned by a subquery used as an expression"))
+  }
+
   test("uncorrelated scalar subquery on testData") {
     // initialize test Data
     testData

     assertResult(Array(Row(5))) {
-      sql("select (select key from testData where key > 3 limit 1) + 1").collect()
+      sql("select (select key from testData where key > 3 order by key limit 1) + 1").collect()
     }

     assertResult(Array(Row(-100))) {
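
Note the `order by key` added to the last test: without an ordering, `limit 1` may return any row satisfying `key > 3`, so asserting `Row(5)` was flaky. A minimal standalone reproduction of the now-deterministic query, assuming a Spark 1.x-era SQLContext (the table setup is an illustrative stand-in for the suite's testData fixture):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object ScalarSubqueryDeterminism {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // Stand-in for the suite's testData fixture: a single "key" column 1..100.
    sqlContext.range(1, 101).toDF("key").registerTempTable("testData")

    // "order by key" pins down which row "limit 1" returns, so the result
    // is deterministically 4 + 1 = 5.
    sqlContext.sql(
      "select (select key from testData where key > 3 order by key limit 1) + 1").show()

    sc.stop()
  }
}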
