Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class Analyzer(
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
other transformExpressions {
case e: SubqueryExpression =>
e.withNewPlan(substituteCTE(e.query, cteRelations))
e.withNewPlan(substituteCTE(e.plan, cteRelations))
}
}
}
Expand Down Expand Up @@ -1091,7 +1091,7 @@ class Analyzer(
f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
// Step 1: Resolve the outer expressions.
var previous: LogicalPlan = null
var current = e.query
var current = e.plan
do {
// Try to resolve the subquery plan using the regular analyzer.
previous = current
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,33 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._

/**
* An interface for subquery that is used in expressions.
* An interface for expressions that contain a [[QueryPlan]].
*/
abstract class SubqueryExpression extends Expression {
abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
/** The id of the subquery expression. */
def exprId: ExprId

/** The logical plan of the query. */
def query: LogicalPlan
/** The plan wrapped by this expression. */
def plan: T

/**
* Either a logical plan or a physical plan. The generated tree string (explain output) uses this
* field to explain the subquery.
*/
def plan: QueryPlan[_]

/** Updates the query with new logical plan. */
def withNewPlan(plan: LogicalPlan): SubqueryExpression
/** Updates the expression with a new plan. */
def withNewPlan(plan: T): PlanExpression[T]

protected def conditionString: String = children.mkString("[", " && ", "]")
}

/**
* A base interface for expressions that contain a [[LogicalPlan]].
*
* It fixes the plan type parameter of [[PlanExpression]] to [[LogicalPlan]] and narrows the
* result type of `withNewPlan`, so that replacing the plan of a subquery expression yields a
* [[SubqueryExpression]] again rather than a general `PlanExpression[LogicalPlan]`.
*/
abstract class SubqueryExpression extends PlanExpression[LogicalPlan] {
/** Updates the expression with a new logical plan; the result is a [[SubqueryExpression]]. */
override def withNewPlan(plan: LogicalPlan): SubqueryExpression
}

object SubqueryExpression {
def hasCorrelatedSubquery(e: Expression): Boolean = {
e.find {
Expand All @@ -60,20 +60,19 @@ object SubqueryExpression {
* Note: `exprId` is used to have a unique name in explain string output.
*/
case class ScalarSubquery(
query: LogicalPlan,
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Unevaluable {
override lazy val resolved: Boolean = childrenResolved && query.resolved
override lazy val resolved: Boolean = childrenResolved && plan.resolved
override lazy val references: AttributeSet = {
if (query.resolved) super.references -- query.outputSet
if (plan.resolved) super.references -- plan.outputSet
else super.references
}
override def dataType: DataType = query.schema.fields.head.dataType
override def dataType: DataType = plan.schema.fields.head.dataType
override def foldable: Boolean = false
override def nullable: Boolean = true
override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan)
override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
}

Expand All @@ -92,19 +91,18 @@ object ScalarSubquery {
* be rewritten into a left semi/anti join during analysis.
*/
case class PredicateSubquery(
query: LogicalPlan,
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
nullAware: Boolean = false,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Predicate with Unevaluable {
override lazy val resolved = childrenResolved && query.resolved
override lazy val references: AttributeSet = super.references -- query.outputSet
override lazy val resolved = childrenResolved && plan.resolved
override lazy val references: AttributeSet = super.references -- plan.outputSet
override def nullable: Boolean = nullAware
override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(plan = plan)
override def semanticEquals(o: Expression): Boolean = o match {
case p: PredicateSubquery =>
query.sameResult(p.query) && nullAware == p.nullAware &&
plan.sameResult(p.plan) && nullAware == p.nullAware &&
children.length == p.children.length &&
children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
case _ => false
Expand Down Expand Up @@ -146,14 +144,13 @@ object PredicateSubquery {
* FROM b)
* }}}
*/
case class ListQuery(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Unevaluable {
override lazy val resolved = false
override def children: Seq[Expression] = Seq.empty
override def dataType: DataType = ArrayType(NullType)
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(query = plan)
override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
override def toString: String = s"list#${exprId.id}"
}

Expand All @@ -168,12 +165,11 @@ case class ListQuery(query: LogicalPlan, exprId: ExprId = NamedExpression.newExp
* WHERE b.id = a.id)
* }}}
*/
case class Exists(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
case class Exists(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Predicate with Unevaluable {
override lazy val resolved = false
override def children: Seq[Expression] = Seq.empty
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): Exists = copy(query = plan)
override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
override def toString: String = s"exists#${exprId.id}"
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
object OptimizeSubqueries extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case s: SubqueryExpression =>
s.withNewPlan(Optimizer.this.execute(s.query))
s.withNewPlan(Optimizer.this.execute(s.plan))
}
}
}
Expand Down Expand Up @@ -1814,7 +1814,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
val newExpression = expression transform {
case s: ScalarSubquery if s.children.nonEmpty =>
subqueries += s
s.query.output.head
s.plan.output.head
}
newExpression.asInstanceOf[E]
}
Expand Down Expand Up @@ -2029,7 +2029,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
// grouping expressions. As a result we need to replace all the scalar subqueries in the
// grouping expressions by their result.
val newGrouping = grouping.map { e =>
subqueries.find(_.semanticEquals(e)).map(_.query.output.head).getOrElse(e)
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
}
Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* All the subqueries of current plan.
*/
def subqueries: Seq[PlanType] = {
expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]})
expressions.flatMap(_.collect {
case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType]
})
}

override protected def innerChildren: Seq[QueryPlan[_]] = subqueries
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class SQLBuilder private (
try {
val replaced = finalPlan.transformAllExpressions {
case s: SubqueryExpression =>
val query = new SQLBuilder(s.query, nextSubqueryId, nextGenAttrId, exprIdMap).toSQL
val query = new SQLBuilder(s.plan, nextSubqueryId, nextGenAttrId, exprIdMap).toSQL
val sql = s match {
case _: ListQuery => query
case _: Exists => s"EXISTS($query)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,16 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}

/**
* The base class for subquery that is used in SparkPlan.
*/
trait ExecSubqueryExpression extends SubqueryExpression {

val executedPlan: SubqueryExec
def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression

// does not have logical plan
override def query: LogicalPlan = throw new UnsupportedOperationException
override def withNewPlan(plan: LogicalPlan): SubqueryExpression =
throw new UnsupportedOperationException

override def plan: SparkPlan = executedPlan

abstract class ExecSubqueryExpression extends PlanExpression[SubqueryExec] {
/**
* Fill the expression with collected result from executed plan.
*/
Expand All @@ -56,30 +44,29 @@ trait ExecSubqueryExpression extends SubqueryExpression {
* This is the physical copy of ScalarSubquery to be used inside SparkPlan.
*/
case class ScalarSubquery(
executedPlan: SubqueryExec,
plan: SubqueryExec,
exprId: ExprId)
extends ExecSubqueryExpression {

override def dataType: DataType = executedPlan.schema.fields.head.dataType
override def dataType: DataType = plan.schema.fields.head.dataType
override def children: Seq[Expression] = Nil
override def nullable: Boolean = true
override def toString: String = executedPlan.simpleString

def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
override def toString: String = plan.simpleString
override def withNewPlan(query: SubqueryExec): ScalarSubquery = copy(plan = query)

override def semanticEquals(other: Expression): Boolean = other match {
case s: ScalarSubquery => executedPlan.sameResult(executedPlan)
case s: ScalarSubquery => plan.sameResult(s.plan)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

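This fixes a small bug in semanticEquals: the old code compared the plan with itself (`executedPlan.sameResult(executedPlan)`) instead of with the other subquery's plan.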

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

case _ => false
}

// The first column of the first row returned by `plan`.
@volatile private var result: Any = null
@volatile private var result: Any = _
@volatile private var updated: Boolean = false

def updateResult(): Unit = {
val rows = plan.executeCollect()
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
sys.error(s"more than one row returned by a subquery used as an expression:\n$plan")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1,
Expand Down Expand Up @@ -108,21 +95,19 @@ case class ScalarSubquery(
*/
case class InSubquery(
child: Expression,
executedPlan: SubqueryExec,
plan: SubqueryExec,
exprId: ExprId,
private var result: Array[Any] = null,
private var updated: Boolean = false) extends ExecSubqueryExpression {

override def dataType: DataType = BooleanType
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = child.nullable
override def toString: String = s"$child IN ${executedPlan.name}"

def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
override def toString: String = s"$child IN ${plan.name}"
override def withNewPlan(plan: SubqueryExec): InSubquery = copy(plan = plan)

override def semanticEquals(other: Expression): Boolean = other match {
case in: InSubquery => child.semanticEquals(in.child) &&
executedPlan.sameResult(in.executedPlan)
case in: InSubquery => child.semanticEquals(in.child) && plan.sameResult(in.plan)
case _ => false
}

Expand Down Expand Up @@ -159,8 +144,8 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
ScalarSubquery(
SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan),
subquery.exprId)
case expressions.PredicateSubquery(plan, Seq(e: Expression), _, exprId) =>
val executedPlan = new QueryExecution(sparkSession, plan).executedPlan
case expressions.PredicateSubquery(query, Seq(e: Expression), _, exprId) =>
val executedPlan = new QueryExecution(sparkSession, query).executedPlan
InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId)
}
}
Expand All @@ -184,9 +169,9 @@ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {
val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
val sameResult = sameSchema.find(_.sameResult(sub.plan))
if (sameResult.isDefined) {
sub.withExecutedPlan(sameResult.get)
sub.withNewPlan(sameResult.get)
} else {
sameSchema += sub.executedPlan
sameSchema += sub.plan
sub
}
}
Expand Down
4 changes: 2 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ abstract class QueryTest extends PlanTest {
p.expressions.foreach {
_.foreach {
case s: SubqueryExpression =>
s.query.foreach(collectData)
s.plan.foreach(collectData)
case _ =>
}
}
Expand Down Expand Up @@ -334,7 +334,7 @@ abstract class QueryTest extends PlanTest {
case p =>
p.transformExpressions {
case s: SubqueryExpression =>
s.withNewPlan(s.query.transformDown(renormalize))
s.withNewPlan(s.plan.transformDown(renormalize))
}
}
val normalized2 = jsonBackPlan.transformDown(renormalize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Benchmark

/**
Expand Down