Commit 77ea002

fix bug

1 parent b1914de commit 77ea002

4 files changed: +82 -72 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 10 additions & 37 deletions
@@ -20,8 +20,7 @@ package org.apache.spark.sql.execution
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.{ExecutionContext, Future}
-import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext

 import org.apache.spark.{broadcast, SparkEnv}
 import org.apache.spark.internal.Logging
@@ -138,51 +137,30 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   }

   /**
-   * List of (uncorrelated scalar subquery, future holding the subquery result) for this plan node.
+   * List of uncorrelated scalar subqueries for this plan node.
    * This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
    */
   @transient
-  private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]
+  private var allSubqueries = new ArrayBuffer[ScalarSubquery]

   /**
    * Finds scalar subquery expressions in this plan node and starts evaluating them.
-   * The list of subqueries are added to [[subqueryResults]].
+   * The subqueries are added to [[allSubqueries]].
    */
   protected def prepareSubqueries(): Unit = {
-    val allSubqueries = expressions.flatMap(_.collect {
-      case e: ScalarSubquery if !e.isExecuted => e
-    }).distinct
-    allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
-      e.updateExecutedState()
-      val futureResult = Future {
-        // Each subquery should return only one row (and one column). We take two here and throws
-        // an exception later if the number of rows is greater than one.
-        e.executedPlan.executeTake(2)
-      }(SparkPlan.subqueryExecutionContext)
-      subqueryResults += e -> futureResult
+    expressions.flatMap(_.collect { case e: ScalarSubquery => e }).distinct.foreach { e =>
+      e.submitSubqueryEvaluated()
+      allSubqueries += e
     }
   }

   /**
-   * Blocks the thread until all subqueries finish evaluation and update the results.
+   * Blocks the thread until all subqueries finish evaluation.
    */
   protected def waitForSubqueries(): Unit = synchronized {
-    // fill in the result of subqueries
-    subqueryResults.foreach { case (e, futureResult) =>
-      val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
-      if (rows.length > 1) {
-        sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
-      }
-      if (rows.length == 1) {
-        assert(rows(0).numFields == 1,
-          s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
-        e.updateResult(rows(0).get(0, e.dataType))
-      } else {
-        // If there is no rows returned, the result should be null.
-        e.updateResult(null)
-      }
+    allSubqueries.foreach { e =>
+      e.awaitSubqueryResult()
     }
-    subqueryResults.clear()
   }

   /**
@@ -393,11 +371,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   }
 }

-object SparkPlan {
-  private[execution] val subqueryExecutionContext = ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
-}
-
 private[sql] trait LeafExecNode extends SparkPlan {
   override def children: Seq[SparkPlan] = Nil
   override def producedAttributes: AttributeSet = outputSet
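Taken together, these hunks move the submit/await bookkeeping off the plan node and onto the shared `ScalarSubquery` instance itself: a subquery reused by several plan nodes is submitted once, and every node that references it can still block on the same result. Under the old scheme, a node whose `prepareSubqueries` filtered out an already-executed subquery (`if !e.isExecuted`) had no future to wait on in `waitForSubqueries`, so it could proceed with a result that was never filled in. Below is a minimal, self-contained sketch of the submit-once/await pattern; `SharedSubquery`, `compute`, and the plain cached thread pool are illustrative stand-ins, not Spark APIs.

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object SubmitOnceSketch {
  // Stand-in for ThreadUtils.newDaemonCachedThreadPool("subquery", 16).
  private val pool = ExecutionContext.fromExecutorService(Executors.newCachedThreadPool())

  // One instance may be shared by many plan nodes; evaluation happens once.
  final class SharedSubquery[T](compute: () => T) {
    private var future: Future[T] = _
    private var submitted = false

    // Idempotent submit: the second and later callers reuse the first Future.
    def submit(): Unit = synchronized {
      if (!submitted) {
        future = Future(compute())(pool)
        submitted = true
      }
    }

    // Every sharing caller can block on the single shared evaluation.
    def await(): T = synchronized(Await.result(future, Duration.Inf))
  }

  def main(args: Array[String]): Unit = {
    val shared = new SharedSubquery(() => { Thread.sleep(50); 3 })
    shared.submit()
    shared.submit()           // no-op: the evaluation is not re-run
    println(shared.await())   // 3, computed exactly once
    pool.shutdown()
  }
}
```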

sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala

Lines changed: 45 additions & 8 deletions
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution

 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration

 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions
@@ -28,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.util.ThreadUtils

 /**
  * A subquery that will return only one row and one column.
@@ -53,15 +56,10 @@ case class ScalarSubquery(
   // the first column in first row from `query`.
   @volatile private var result: Any = null
   @volatile private var updated: Boolean = false
-  @volatile private var executed: Boolean = false
+  @volatile private var evaluated: Boolean = false
+  @volatile private var futureResult: Future[Array[InternalRow]] = _

-  def isExecuted: Boolean = executed
-
-  def updateExecutedState() : Unit = {
-    executed = true
-  }
-
-  def updateResult(v: Any): Unit = {
+  private def updateResult(v: Any): Unit = {
     result = v
     updated = true
   }
@@ -76,6 +74,40 @@ case class ScalarSubquery(
     Literal.create(result, dataType).doGenCode(ctx, ev)
   }

+  /**
+   * Submits the subquery for evaluation. Does nothing if it has already been submitted.
+   */
+  def submitSubqueryEvaluated(): Unit = synchronized {
+    if (!evaluated) {
+      futureResult = Future {
+        // Each subquery should return only one row (and one column). We take two here and throw
+        // an exception later if the number of rows is greater than one.
+        executedPlan.executeTake(2)
+      }(ScalarSubquery.subqueryExecutionContext)
+      evaluated = true
+    }
+  }
+
+  /**
+   * Blocks the thread until the evaluation of the subquery has finished.
+   */
+  def awaitSubqueryResult(): Unit = synchronized {
+    if (!updated) {
+      val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
+      if (rows.length > 1) {
+        sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
+      }
+      if (rows.length == 1) {
+        assert(rows(0).numFields == 1,
+          s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
+        updateResult(rows(0).get(0, dataType))
+      } else {
+        // If no rows are returned, the result should be null.
+        updateResult(null)
+      }
+    }
+  }
+
   override def equals(o: Any): Boolean = o match {
     case other: ScalarSubquery => this.eq(other)
     case _ => false
@@ -84,6 +116,11 @@ case class ScalarSubquery(
   override def hashCode: Int = exprId.hashCode()
 }

+object ScalarSubquery {
+  private[execution] val subqueryExecutionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
+}
+
 /**
  * A wrapper for reused uncorrelated ScalarSubquery to avoid the re-computing for subqueries with
  * the same "canonical" logical plan in a query, because uncorrelated subqueries with the same
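One detail worth spelling out in `submitSubqueryEvaluated`/`awaitSubqueryResult`: fetching at most two rows via `executeTake(2)` is enough to distinguish all three cases, zero rows (the subquery evaluates to NULL), exactly one row (the scalar value), and more than one row (an error), without materializing the full result set. A compact sketch of that case analysis; `scalarFromTake2` is a hypothetical helper, and the `Option` return stands in for the patch's `updateResult`/`sys.error` calls:

```scala
// Sketch of the three-way case analysis behind executeTake(2):
// `rows` plays the role of executedPlan.executeTake(2).
def scalarFromTake2[A](rows: Seq[A]): Option[A] = rows match {
  case Seq()  => None      // no rows: the subquery evaluates to NULL
  case Seq(v) => Some(v)   // exactly one row: the scalar value
  case _      => sys.error("more than one row returned by a subquery used as an expression")
}

assert(scalarFromTake2(Seq.empty[Int]).isEmpty)
assert(scalarFromTake2(Seq(3)).contains(3))
```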

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 0 additions & 27 deletions
@@ -2896,31 +2896,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       sql(s"SELECT '$literal' AS DUMMY"),
       Row(s"$expected") :: Nil)
   }
-
-  test("SPARK-16456: Reuse the uncorrelated scalar subqueries with the same logical plan") {
-    withTempTable("t1", "t2", "t3") {
-      val df = (1 to 3).map(i => (i, i)).toDF("key", "value")
-      df.createOrReplaceTempView("t1")
-      df.createOrReplaceTempView("t2")
-      df.createOrReplaceTempView("t3")
-      checkAnswer(
-        sql(
-          """
-            |WITH max_test AS
-            |(
-            |  SELECT max(key) as max_key FROM t1
-            |),
-            |max_test2 AS
-            |(
-            |  SELECT max(key) as max_key FROM t1
-            |)
-            |SELECT key FROM t2
-            |WHERE key = (SELECT max_key FROM max_test) and value = (SELECT max_key FROM max_test)
-            |UNION ALL
-            |SELECT key FROM t3
-            |WHERE key = (SELECT max_key FROM max_test) and value = (SELECT max_key FROM max_test2)
-          """.stripMargin
-        ), Row(3) :: Row(3) :: Nil)
-    }
-  }
 }

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 27 additions & 0 deletions
@@ -571,4 +571,31 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
       Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
         Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
   }
+
+  test("SPARK-16456: Reuse the uncorrelated scalar subqueries with the same logical plan") {
+    withTempTable("t1", "t2", "t3") {
+      val df = (1 to 3).map(i => (i, i)).toDF("key", "value")
+      df.createOrReplaceTempView("t1")
+      df.createOrReplaceTempView("t2")
+      df.createOrReplaceTempView("t3")
+      checkAnswer(
+        sql(
+          """
+            |WITH max_test AS
+            |(
+            |  SELECT max(key) as max_key FROM t1
+            |),
+            |max_test2 AS
+            |(
+            |  SELECT max(key) as max_key FROM t1
+            |)
+            |SELECT key FROM t2
+            |WHERE key = (SELECT max_key FROM max_test) and value = (SELECT max_key FROM max_test)
+            |UNION ALL
+            |SELECT key FROM t3
+            |WHERE key = (SELECT max_key FROM max_test) and value = (SELECT max_key FROM max_test2)
+          """.stripMargin
+        ), Row(3) :: Row(3) :: Nil)
+    }
+  }
 }
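Note that the test is moved verbatim from `SQLQuerySuite` (deleted above) into the more specific `SubquerySuite`. All four scalar subqueries in the query reduce to the same logical plan, `SELECT max(key) FROM t1`, so the reuse machinery should evaluate it only once; with the rows (1,1) through (3,3), `max_key` is 3, each UNION ALL branch matches exactly the (3, 3) row, and the expected answer is `Row(3) :: Row(3) :: Nil`.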
