From c78b0f07d739a40ea0d6bbb9ababad016931a274 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Fri, 9 Jul 2021 11:05:30 -0700 Subject: [PATCH 1/5] optimize --- .../sql/catalyst/expressions/subquery.scala | 6 ++- .../optimizer/DecorrelateInnerQuery.scala | 27 ++++++++-- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../sql/catalyst/optimizer/subquery.scala | 41 ++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 22 ++++++++ .../apache/spark/sql/internal/SQLConf.scala | 10 ++++ .../org/apache/spark/sql/SubquerySuite.scala | 53 ++++++++++++++++++- 7 files changed, 153 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index d157330142977..0c7452a37d54a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -126,13 +126,15 @@ object SubExprUtils extends PredicateHelper { /** * Returns an expression after removing the OuterReference shell. */ - def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r } + def stripOuterReference[E <: Expression](e: E): E = { + e.transform { case OuterReference(r) => r }.asInstanceOf[E] + } /** * Returns the list of expressions after removing the OuterReference shell from each of * the expression. */ - def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference) + def stripOuterReferences[E <: Expression](e: Seq[E]): Seq[E] = e.map(stripOuterReference) /** * Returns the logical plan after removing the OuterReference shell from all the expressions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala index f30dd9949f64f..f060f6c0b4b60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala @@ -156,6 +156,23 @@ object DecorrelateInnerQuery extends PredicateHelper { expressions.map(replaceOuterReference(_, outerReferenceMap)) } + /** + * Replace all outer references in the given named expressions and keep the output + * attributes unchanged. + */ + private def replaceOuterInNamedExpressions( + expressions: Seq[NamedExpression], + outerReferenceMap: Map[Attribute, Attribute]): Seq[NamedExpression] = { + expressions.map { expr => + val newExpr = replaceOuterReference(expr, outerReferenceMap) + if (!newExpr.toAttribute.semanticEquals(expr.toAttribute)) { + Alias(newExpr, expr.name)(expr.exprId) + } else { + newExpr + } + } + } + /** * Return all references that are presented in the join conditions but not in the output * of the given named expressions. @@ -429,8 +446,9 @@ object DecorrelateInnerQuery extends PredicateHelper { val newOuterReferences = parentOuterReferences ++ outerReferences val (newChild, joinCond, outerReferenceMap) = decorrelate(child, newOuterReferences, aggregated) - // Replace all outer references in the original project list. - val newProjectList = replaceOuterReferences(projectList, outerReferenceMap) + // Replace all outer references in the original project list and keep the output + // attributes unchanged. + val newProjectList = replaceOuterInNamedExpressions(projectList, outerReferenceMap) // Preserve required domain attributes in the join condition by adding the missing // references to the new project list. val referencesToAdd = missingReferences(newProjectList, joinCond) @@ -442,9 +460,10 @@ object DecorrelateInnerQuery extends PredicateHelper { val newOuterReferences = parentOuterReferences ++ outerReferences val (newChild, joinCond, outerReferenceMap) = decorrelate(child, newOuterReferences, aggregated = true) - // Replace all outer references in grouping and aggregate expressions. + // Replace all outer references in grouping and aggregate expressions, and keep + // the output attributes unchanged. val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap) - val newAggExpr = replaceOuterReferences(aggregateExpressions, outerReferenceMap) + val newAggExpr = replaceOuterInNamedExpressions(aggregateExpressions, outerReferenceMap) // Add all required domain attributes to both grouping and aggregate expressions. val referencesToAdd = missingReferences(newAggExpr, joinCond) val newAggregate = a.copy( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c79fd7a87a83b..5b3c08969e8af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -179,6 +179,7 @@ abstract class Optimizer(catalogManager: CatalogManager) // non-nullable when an empty relation child of a Union is removed UpdateAttributeNullability) :: Batch("Pullup Correlated Expressions", Once, + OptimizeOneRowRelationSubquery, PullupCorrelatedPredicates) :: // Subquery batch applies the optimizer rules recursively. Therefore, it makes no sense // to enforce idempotence on it and we change this batch from Once to FixedPoint(1). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 53448fbe92d4c..44a07196a584f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ @@ -711,3 +712,43 @@ object RewriteLateralSubquery extends Rule[LogicalPlan] { Join(left, newRight, joinType, newCond, JoinHint.NONE) } } + +/** + * This rule optimizes subqueries with OneRowRelation as leaf nodes. + */ +object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { + + object OneRowSubquery { + def unapply(plan: LogicalPlan): Option[Seq[NamedExpression]] = { + CollapseProject(EliminateSubqueryAliases(plan)) match { + case Project(projectList, _: OneRowRelation) => Some(stripOuterReferences(projectList)) + case _ => None + } + } + } + + /** + * Rewrite a subquery expression into one or more expressions. The rewrite can only be done + * if there is no nested subqueries in the subquery plan. + */ + private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries { + case LateralJoin(left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _), _, None) + if right.plan.subqueriesAll.isEmpty => + Project(left.output ++ projectList, left) + case p: LogicalPlan => p.transformExpressionsUpWithPruning( + _.containsPattern(SCALAR_SUBQUERY)) { + case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _) + if s.plan.subqueriesAll.isEmpty => + assert(projectList.size == 1) + projectList.head + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.optimizeOneRowRelationSubquery) { + plan + } else { + rewrite(plan) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index db7fd5c3a1079..563f6451cd5a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -435,6 +435,28 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] subqueries ++ subqueries.flatMap(_.subqueriesAll) } + /** + * Returns a copy of this node where the given partial function has been recursively applied + * first to this node's children, then this node's subqueries, and finally this node itself + * (post-order). When the partial function does not apply to a given node, it is left unchanged. + */ + def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { + val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { + override def isDefinedAt(x: PlanType): Boolean = true + + override def apply(plan: PlanType): PlanType = { + val transformed = plan transformExpressionsUp { + case planExpression: PlanExpression[PlanType] => + val newPlan = planExpression.plan.transformUpWithSubqueries(f) + planExpression.withNewPlan(newPlan) + } + f.applyOrElse[PlanType, PlanType](transformed, identity) + } + } + + transformUp(g) + } + /** * A variant of `collect`. This method not only apply the given function to all elements in this * plan, also considering all the plans in its (nested) subqueries diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e9c5f6e9bef04..e399ca73c3dc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2593,6 +2593,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val OPTIMIZE_ONE_ROW_RELATION_SUBQUERY = + buildConf("spark.sql.optimizer.optimizeOneRowRelationSubquery") + .internal() + .doc("When true, the optimizer will inline subqueries with OneRowRelation as leaf nodes.") + .version("3.2.0") + .booleanConf + .createWithDefault(true) + val TOP_K_SORT_FALLBACK_THRESHOLD = buildConf("spark.sql.execution.topKSortFallbackThreshold") .internal() @@ -4053,6 +4061,8 @@ class SQLConf extends Serializable with Logging { def decorrelateInnerQueryEnabled: Boolean = getConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED) + def optimizeOneRowRelationSubquery: Boolean = getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY) + def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS) def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index e06af08147a21..f5ee43f1c3dee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -1838,7 +1838,8 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } test("Subquery reuse across the whole plan") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> "false") { val df = sql( """ |SELECT (SELECT avg(key) FROM testData), (SELECT (SELECT avg(key) FROM testData)) @@ -1876,4 +1877,54 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark "ReusedSubqueryExec should reuse an existing subquery") } } + + test("SPARK-36063: optimize one row relation subqueries") { + withTempView("t") { + Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") + Seq( + "select (select c1) from t", + "select (select a) from t as t(a, b)", + "select (select c from (select c from (select c1 as c))) from t", + "select (select (select a) from (select c1, c2) t(a, b)) from t", + "select s.c1 from t, lateral (select c1, c2) s" + ).foreach { query => + Seq(true, false).foreach { enabled => + withSQLConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> enabled.toString) { + val df = sql(query) + val plan = df.queryExecution.optimizedPlan + val joins = plan.collectWithSubqueries { case j: Join => j } + assert(joins.isEmpty == enabled) + checkAnswer(df, Row(0) :: Row(1) :: Nil) + } + } + } + } + } + + test("SPARK-36063: optimize one row relation subqueries (negative case)") { + withTempView("t") { + Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") + withSQLConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> "true") { + Seq( + // With additional operators + ("select (select c1 where c2 = 1) from t", Row(0) :: Row(null) :: Nil), + // With non-deterministic expressions + ("select (select floor(r) from (select c1 + rand() as r)) from t", + Row(0) :: Row(1) :: Nil), + // With non-empty lateral join condition + ("select * from t join lateral (select c1, c2) s on t.c1 = s.c2", Nil), + // With nested subqueries that cannot be optimized + ("select (select (select a where a = 1) from (select c1 as a)) from t", + Row(null) :: Row(1) :: Nil), + ("select * from t, lateral (select (select a where a = 1) from (select c1 as a))", + Row(0, 1, null) :: Row(1, 2, 1) :: Nil) + ).foreach { case (query, expected) => + val df = sql(query) + val joins = df.queryExecution.optimizedPlan.collect { case j: Join => j } + assert(joins.nonEmpty) + checkAnswer(df, expected) + } + } + } + } } From 33a0e9a9470a4c97aeb4a701262a41e12ef2b960 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Fri, 9 Jul 2021 17:04:02 -0700 Subject: [PATCH 2/5] fix tests --- .../sql/catalyst/optimizer/subquery.scala | 4 ++-- .../DecorrelateInnerQuerySuite.scala | 20 ++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 44a07196a584f..1a3be18483912 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -733,12 +733,12 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { */ private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries { case LateralJoin(left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _), _, None) - if right.plan.subqueriesAll.isEmpty => + if right.plan.subqueriesAll.isEmpty && right.joinCond.isEmpty => Project(left.output ++ projectList, left) case p: LogicalPlan => p.transformExpressionsUpWithPruning( _.containsPattern(SCALAR_SUBQUERY)) { case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _) - if s.plan.subqueriesAll.isEmpty => + if s.plan.subqueriesAll.isEmpty && s.joinCond.isEmpty => assert(projectList.size == 1) projectList.head } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala index 92995c2e85eda..b8886a5c0b2fe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala @@ -32,6 +32,7 @@ class DecorrelateInnerQuerySuite extends PlanTest { val x = AttributeReference("x", IntegerType)() val y = AttributeReference("y", IntegerType)() val z = AttributeReference("z", IntegerType)() + val t0 = OneRowRelation() val testRelation = LocalRelation(a, b, c) val testRelation2 = LocalRelation(x, y, z) @@ -203,23 +204,24 @@ class DecorrelateInnerQuerySuite extends PlanTest { test("correlated values in project") { val outerPlan = testRelation2 - val innerPlan = Project(Seq(OuterReference(x), OuterReference(y)), OneRowRelation()) - val correctAnswer = Project(Seq(x, y), DomainJoin(Seq(x, y), OneRowRelation())) + val innerPlan = Project(Seq(OuterReference(x).as("x1"), OuterReference(y).as("y1")), t0) + val correctAnswer = Project( + Seq(x.as("x1"), y.as("y1"), x, y), DomainJoin(Seq(x, y), t0)) check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y)) } test("correlated values in project with alias") { val outerPlan = testRelation2 val innerPlan = - Project(Seq(OuterReference(x), 'y1, 'sum), + Project(Seq(OuterReference(x).as("x1"), 'y1, 'sum), Project(Seq( OuterReference(x), OuterReference(y).as("y1"), Add(OuterReference(x), OuterReference(y)).as("sum")), testRelation)).analyze val correctAnswer = - Project(Seq(x, 'y1, 'sum, y), - Project(Seq(x, y.as("y1"), (x + y).as("sum"), y), + Project(Seq(x.as("x1"), 'y1, 'sum, x, y), + Project(Seq(x.as(x.name), y.as("y1"), (x + y).as("sum"), x, y), DomainJoin(Seq(x, y), testRelation))).analyze check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y)) } @@ -228,13 +230,13 @@ class DecorrelateInnerQuerySuite extends PlanTest { val outerPlan = testRelation2 val innerPlan = Project( - Seq(OuterReference(x)), + Seq(OuterReference(x).as("x1")), Filter( And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)), testRelation ) ) - val correctAnswer = Project(Seq(a, c), Filter(b === 1, testRelation)) + val correctAnswer = Project(Seq(a.as("x1"), a, c), Filter(b === 1, testRelation)) check(innerPlan, outerPlan, correctAnswer, Seq(x === a, x + y === c)) } @@ -242,14 +244,14 @@ class DecorrelateInnerQuerySuite extends PlanTest { val outerPlan = testRelation2 val innerPlan = Project( - Seq(OuterReference(y)), + Seq(OuterReference(y).as("y1")), Filter( And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)), testRelation ) ) val correctAnswer = - Project(Seq(y, a, c), + Project(Seq(y.as("y1"), y, a, c), Filter(b === 1, DomainJoin(Seq(y), testRelation) ) From 73afab10b299be3072553db90a2bb4125ae63a66 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Wed, 14 Jul 2021 18:03:50 -0700 Subject: [PATCH 3/5] address comments --- .../sql/catalyst/optimizer/DecorrelateInnerQuery.scala | 2 +- .../org/apache/spark/sql/catalyst/optimizer/subquery.scala | 6 +++--- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 -- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala index f060f6c0b4b60..71f3897ccf50b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala @@ -162,7 +162,7 @@ object DecorrelateInnerQuery extends PredicateHelper { */ private def replaceOuterInNamedExpressions( expressions: Seq[NamedExpression], - outerReferenceMap: Map[Attribute, Attribute]): Seq[NamedExpression] = { + outerReferenceMap: AttributeMap[Attribute]): Seq[NamedExpression] = { expressions.map { expr => val newExpr = replaceOuterReference(expr, outerReferenceMap) if (!newExpr.toAttribute.semanticEquals(expr.toAttribute)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 1a3be18483912..a08f200ad35d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -733,19 +733,19 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { */ private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries { case LateralJoin(left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _), _, None) - if right.plan.subqueriesAll.isEmpty && right.joinCond.isEmpty => + if right.plan.subqueries.isEmpty && right.joinCond.isEmpty => Project(left.output ++ projectList, left) case p: LogicalPlan => p.transformExpressionsUpWithPruning( _.containsPattern(SCALAR_SUBQUERY)) { case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _) - if s.plan.subqueriesAll.isEmpty && s.joinCond.isEmpty => + if s.plan.subqueries.isEmpty && s.joinCond.isEmpty => assert(projectList.size == 1) projectList.head } } def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.optimizeOneRowRelationSubquery) { + if (!conf.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY)) { plan } else { rewrite(plan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e399ca73c3dc2..73a7063fe4772 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4061,8 +4061,6 @@ class SQLConf extends Serializable with Logging { def decorrelateInnerQueryEnabled: Boolean = getConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED) - def optimizeOneRowRelationSubquery: Boolean = getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY) - def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS) def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT) From c148cfee696e10ea9ca3029bc4eb1f1435c57c69 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Thu, 15 Jul 2021 16:45:38 -0700 Subject: [PATCH 4/5] update test cases --- .../spark/sql/catalyst/dsl/package.scala | 7 + .../sql/catalyst/optimizer/subquery.scala | 8 +- .../OptimizeOneRowRelationSubquerySuite.scala | 161 ++++++++++++++++++ .../org/apache/spark/sql/SubquerySuite.scala | 50 ------ 4 files changed, 174 insertions(+), 52 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index d0f63ba7412f7..7fb0e32645dcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -390,6 +390,13 @@ package object dsl { condition: Option[Expression] = None): LogicalPlan = Join(logicalPlan, otherPlan, joinType, condition, JoinHint.NONE) + def lateralJoin( + otherPlan: LogicalPlan, + joinType: JoinType = Inner, + condition: Option[Expression] = None): LogicalPlan = { + LateralJoin(logicalPlan, LateralSubquery(otherPlan), joinType, condition) + } + def cogroup[Key: Encoder, Left: Encoder, Right: Encoder, Result: Encoder]( otherPlan: LogicalPlan, func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index a08f200ad35d6..6d6b8b7d8aca8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -727,18 +727,22 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { } } + private def hasCorrelatedSubquery(plan: LogicalPlan): Boolean = { + plan.find(_.expressions.exists(SubqueryExpression.hasCorrelatedSubquery)).isDefined + } + /** * Rewrite a subquery expression into one or more expressions. The rewrite can only be done * if there is no nested subqueries in the subquery plan. */ private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries { case LateralJoin(left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _), _, None) - if right.plan.subqueries.isEmpty && right.joinCond.isEmpty => + if !hasCorrelatedSubquery(right.plan) && right.joinCond.isEmpty => Project(left.output ++ projectList, left) case p: LogicalPlan => p.transformExpressionsUpWithPruning( _.containsPattern(SCALAR_SUBQUERY)) { case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _) - if s.plan.subqueries.isEmpty && s.joinCond.isEmpty => + if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty => assert(projectList.size == 1) projectList.head } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala new file mode 100644 index 0000000000000..686df035d4c66 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, ScalarSubquery} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, LocalRelation, LogicalPlan, OneRowRelation, Project} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf + +class OptimizeOneRowRelationSubquerySuite extends PlanTest { + + private var optimizeOneRowRelationSubqueryEnabled: Boolean = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + optimizeOneRowRelationSubqueryEnabled = + SQLConf.get.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY) + SQLConf.get.setConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY, true) + } + + protected override def afterAll(): Unit = { + SQLConf.get.setConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY, + optimizeOneRowRelationSubqueryEnabled) + super.afterAll() + } + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subquery", Once, + OptimizeOneRowRelationSubquery, + PullupCorrelatedPredicates) :: Nil + } + + private def assertHasDomainJoin(plan: LogicalPlan): Unit = { + assert(plan.collectWithSubqueries { case d: DomainJoin => d }.nonEmpty, + s"Plan does not contain DomainJoin:\n$plan") + } + + val t0 = OneRowRelation() + val a = 'a.int + val b = 'b.int + val t1 = LocalRelation(a, b) + val t2 = LocalRelation('c.int, 'd.int) + + test("Optimize scalar subquery with a single project") { + // SELECT (SELECT a) FROM t1 + val query = t1.select(ScalarSubquery(t0.select('a)).as("sub")) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = t1.select('a.as("sub")) + comparePlans(optimized, correctAnswer.analyze) + } + + test("Optimize lateral subquery with a single project") { + Seq(Inner, LeftOuter, Cross).foreach { joinType => + // SELECT * FROM t1 JOIN LATERAL (SELECT a, b) + val query = t1.lateralJoin(t0.select('a, 'b), joinType, None) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = t1.select('a, 'b, 'a.as("a"), 'b.as("b")) + comparePlans(optimized, correctAnswer.analyze) + } + } + + test("Optimize subquery with subquery alias") { + val inner = t0.select('a).as("t2") + val query = t1.select(ScalarSubquery(inner).as("sub")) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = t1.select('a.as("sub")) + comparePlans(optimized, correctAnswer.analyze) + } + + test("Optimize scalar subquery with multiple projects") { + // SELECT (SELECT a1 + b1 FROM (SELECT a AS a1, b AS b1)) FROM t1 + val inner = t0.select('a.as("a1"), 'b.as("b1")).select(('a1 + 'b1).as("c")) + val query = t1.select(ScalarSubquery(inner).as("sub")) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = Project(Alias(Alias(a + b, "c")(), "sub")() :: Nil, t1) + comparePlans(optimized, correctAnswer) + } + + test("Optimize lateral subquery with multiple projects") { + Seq(Inner, LeftOuter, Cross).foreach { joinType => + val inner = t0.select('a.as("a1"), 'b.as("b1")) + .select(('a1 + 'b1).as("c1"), ('a1 - 'b1).as("c2")) + val query = t1.lateralJoin(inner, joinType, None) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = t1.select('a, 'b, ('a + 'b).as("c1"), ('a - 'b).as("c2")) + comparePlans(optimized, correctAnswer.analyze) + } + } + + test("Optimize subquery with nested correlated subqueries") { + // SELECT (SELECT (SELECT b) FROM (SELECT a AS b)) FROM t1 + val inner = t0.select('a.as("b")).select(ScalarSubquery(t0.select('b)).as("s")) + val query = t1.select(ScalarSubquery(inner).as("sub")) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = Project(Alias(Alias(a, "s")(), "sub")() :: Nil, t1) + comparePlans(optimized, correctAnswer) + } + + test("Batch should be idempotent") { + // SELECT (SELECT 1 WHERE a = a + 1) FROM t1 + val inner = t0.select(1).where('a === 'a + 1) + val query = t1.select(ScalarSubquery(inner).as("sub")) + val optimized = Optimize.execute(query.analyze) + val doubleOptimized = Optimize.execute(optimized) + comparePlans(optimized, doubleOptimized, checkAnalysis = false) + } + + test("Should not optimize scalar subquery with operators other than project") { + // SELECT (SELECT a AS a1 WHERE a = 1) FROM t1 + val inner = t0.where('a === 1).select('a.as("a1")) + val query = t1.select(ScalarSubquery(inner).as("sub")) + val optimized = Optimize.execute(query.analyze) + assertHasDomainJoin(optimized) + } + + test("Should not optimize subquery with non-deterministic expressions") { + // SELECT (SELECT r FROM (SELECT a + rand() AS r)) FROM t1 + val inner = t0.select(('a + rand(0)).as("r")).select('r) + val query = t1.select(ScalarSubquery(inner).as("sub")) + val optimized = Optimize.execute(query.analyze) + assertHasDomainJoin(optimized) + } + + test("Should not optimize lateral join with non-empty join conditions") { + Seq(Inner, LeftOuter).foreach { joinType => + // SELECT * FROM t1 JOIN LATERAL (SELECT a AS a1, b AS b1) ON a = b1 + val query = t1.lateralJoin(t0.select('a.as("a1"), 'b.as("b1")), joinType, Some('a === 'b1)) + val optimized = Optimize.execute(query.analyze) + assertHasDomainJoin(optimized) + } + } + + test("Should not optimize subquery with nested subqueries") { + // SELECT (SELECT (SELECT a WHERE a = 1) FROM (SELECT a AS a)) FROM t1 + val inner = t0.select('a).where('a === 1) + val subquery = t0.select('a.as("a")) + .select(ScalarSubquery(inner).as("s")).select('s + 1) + val query = t1.select(ScalarSubquery(subquery).as("sub")) + val optimized = Optimize.execute(query.analyze) + assertHasDomainJoin(optimized) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index f5ee43f1c3dee..c3362b377e152 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -1877,54 +1877,4 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark "ReusedSubqueryExec should reuse an existing subquery") } } - - test("SPARK-36063: optimize one row relation subqueries") { - withTempView("t") { - Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") - Seq( - "select (select c1) from t", - "select (select a) from t as t(a, b)", - "select (select c from (select c from (select c1 as c))) from t", - "select (select (select a) from (select c1, c2) t(a, b)) from t", - "select s.c1 from t, lateral (select c1, c2) s" - ).foreach { query => - Seq(true, false).foreach { enabled => - withSQLConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> enabled.toString) { - val df = sql(query) - val plan = df.queryExecution.optimizedPlan - val joins = plan.collectWithSubqueries { case j: Join => j } - assert(joins.isEmpty == enabled) - checkAnswer(df, Row(0) :: Row(1) :: Nil) - } - } - } - } - } - - test("SPARK-36063: optimize one row relation subqueries (negative case)") { - withTempView("t") { - Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") - withSQLConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> "true") { - Seq( - // With additional operators - ("select (select c1 where c2 = 1) from t", Row(0) :: Row(null) :: Nil), - // With non-deterministic expressions - ("select (select floor(r) from (select c1 + rand() as r)) from t", - Row(0) :: Row(1) :: Nil), - // With non-empty lateral join condition - ("select * from t join lateral (select c1, c2) s on t.c1 = s.c2", Nil), - // With nested subqueries that cannot be optimized - ("select (select (select a where a = 1) from (select c1 as a)) from t", - Row(null) :: Row(1) :: Nil), - ("select * from t, lateral (select (select a where a = 1) from (select c1 as a))", - Row(0, 1, null) :: Row(1, 2, 1) :: Nil) - ).foreach { case (query, expected) => - val df = sql(query) - val joins = df.queryExecution.optimizedPlan.collect { case j: Join => j } - assert(joins.nonEmpty) - checkAnswer(df, expected) - } - } - } - } } From 7ba19740668cc8d0f28a83482344b7828ccaafff Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Mon, 19 Jul 2021 14:05:47 -0700 Subject: [PATCH 5/5] address comments --- .../spark/sql/catalyst/plans/QueryPlan.scala | 23 ++++++++----------- .../OptimizeOneRowRelationSubquerySuite.scala | 20 +++++++++------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 563f6451cd5a6..3c9946ba3772b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -437,24 +437,19 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] /** * Returns a copy of this node where the given partial function has been recursively applied - * first to this node's children, then this node's subqueries, and finally this node itself - * (post-order). When the partial function does not apply to a given node, it is left unchanged. + * first to the subqueries in this node's children, then this node's children, and finally + * this node itself (post-order). When the partial function does not apply to a given node, + * it is left unchanged. */ def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { - val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { - override def isDefinedAt(x: PlanType): Boolean = true - - override def apply(plan: PlanType): PlanType = { - val transformed = plan transformExpressionsUp { - case planExpression: PlanExpression[PlanType] => - val newPlan = planExpression.plan.transformUpWithSubqueries(f) - planExpression.withNewPlan(newPlan) - } - f.applyOrElse[PlanType, PlanType](transformed, identity) + transformUp { case plan => + val transformed = plan transformExpressionsUp { + case planExpression: PlanExpression[PlanType] => + val newPlan = planExpression.plan.transformUpWithSubqueries(f) + planExpression.withNewPlan(newPlan) } + f.applyOrElse[PlanType, PlanType](transformed, identity) } - - transformUp(g) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala index 686df035d4c66..4203859226fae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis.CleanupAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, ScalarSubquery} +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, LocalRelation, LogicalPlan, OneRowRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, LocalRelation, LogicalPlan, OneRowRelation} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf @@ -46,7 +47,9 @@ class OptimizeOneRowRelationSubquerySuite extends PlanTest { val batches = Batch("Subquery", Once, OptimizeOneRowRelationSubquery, - PullupCorrelatedPredicates) :: Nil + PullupCorrelatedPredicates) :: + Batch("Cleanup", FixedPoint(10), + CleanupAliases) :: Nil } private def assertHasDomainJoin(plan: LogicalPlan): Unit = { @@ -91,8 +94,8 @@ class OptimizeOneRowRelationSubquerySuite extends PlanTest { val inner = t0.select('a.as("a1"), 'b.as("b1")).select(('a1 + 'b1).as("c")) val query = t1.select(ScalarSubquery(inner).as("sub")) val optimized = Optimize.execute(query.analyze) - val correctAnswer = Project(Alias(Alias(a + b, "c")(), "sub")() :: Nil, t1) - comparePlans(optimized, correctAnswer) + val correctAnswer = t1.select(('a + 'b).as("c").as("sub")) + comparePlans(optimized, correctAnswer.analyze) } test("Optimize lateral subquery with multiple projects") { @@ -111,8 +114,8 @@ class OptimizeOneRowRelationSubquerySuite extends PlanTest { val inner = t0.select('a.as("b")).select(ScalarSubquery(t0.select('b)).as("s")) val query = t1.select(ScalarSubquery(inner).as("sub")) val optimized = Optimize.execute(query.analyze) - val correctAnswer = Project(Alias(Alias(a, "s")(), "sub")() :: Nil, t1) - comparePlans(optimized, correctAnswer) + val correctAnswer = t1.select('a.as("s").as("sub")) + comparePlans(optimized, correctAnswer.analyze) } test("Batch should be idempotent") { @@ -149,8 +152,9 @@ class OptimizeOneRowRelationSubquerySuite extends PlanTest { } } - test("Should not optimize subquery with nested subqueries") { + test("Should not optimize subquery with nested subqueries that can't be optimized") { // SELECT (SELECT (SELECT a WHERE a = 1) FROM (SELECT a AS a)) FROM t1 + // Filter (a = 1) cannot be optimized. val inner = t0.select('a).where('a === 1) val subquery = t0.select('a.as("a")) .select(ScalarSubquery(inner).as("s")).select('s + 1)