From 46a227eddd315ea74e29b3d4a7eca5c0f0ab258f Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Fri, 18 Sep 2020 11:40:54 +0900
Subject: [PATCH 1/4] Fix

---
 .../sql/catalyst/optimizer/subquery.scala     | 82 ++++++++++++-------
 1 file changed, 53 insertions(+), 29 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index a168dcd7a83f5..837eab84e0881 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -338,20 +338,15 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
 object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
   /**
    * Extract all correlated scalar subqueries from an expression. The subqueries are collected using
-   * the given collector. To avoid the reuse of `exprId`s, this method generates new `exprId`
-   * for the subqueries and rewrite references in the given `expression`.
-   * This method returns extracted subqueries and the corresponding `exprId`s and these values
-   * will be used later in `constructLeftJoins` for building the child plan that
-   * returns subquery output with the `exprId`s.
+   * the given collector. The expression is rewritten and returned.
    */
   private def extractCorrelatedScalarSubqueries[E <: Expression](
       expression: E,
-      subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): E = {
+      subqueries: ArrayBuffer[ScalarSubquery]): E = {
     val newExpression = expression transform {
       case s: ScalarSubquery if s.children.nonEmpty =>
-        val newExprId = NamedExpression.newExprId
-        subqueries += s -> newExprId
-        s.plan.output.head.withExprId(newExprId)
+        subqueries += s
+        s.plan.output.head
     }
     newExpression.asInstanceOf[E]
   }
@@ -512,19 +507,23 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
 
   /**
    * Construct a new child plan by left joining the given subqueries to a base plan.
+   * This method returns the child plan and an attribute mapping
+   * for the updated `ExprId`s of subqueries. If a non-empty mapping is returned,
+   * this rule rewrites subquery references in the parent plan based on it.
    */
   private def constructLeftJoins(
       child: LogicalPlan,
-      subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): LogicalPlan = {
-    subqueries.foldLeft(child) {
-      case (currentChild, (ScalarSubquery(query, conditions, _), newExprId)) =>
+      subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
+    val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
+    val newChild = subqueries.foldLeft(child) {
+      case (currentChild, ScalarSubquery(query, conditions, _)) =>
         val origOutput = query.output.head
 
         val resultWithZeroTups = evalSubqueryOnZeroTups(query)
         if (resultWithZeroTups.isEmpty) {
           // CASE 1: Subquery guaranteed not to have the COUNT bug
           Project(
-            currentChild.output :+ Alias(origOutput, origOutput.name)(exprId = newExprId),
+            currentChild.output :+ origOutput,
             Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
         } else {
           // Subquery might have the COUNT bug. Add appropriate corrections.
@@ -544,12 +543,13 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
 
           if (havingNode.isEmpty) {
             // CASE 2: Subquery with no HAVING clause
+            val subqueryResultExpr =
+              Alias(If(IsNull(alwaysTrueRef),
+                resultWithZeroTups.get,
+                aggValRef), origOutput.name)()
+            subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute))
             Project(
-              currentChild.output :+
-                Alias(
-                  If(IsNull(alwaysTrueRef),
-                    resultWithZeroTups.get,
-                    aggValRef), origOutput.name)(exprId = newExprId),
+              currentChild.output :+ subqueryResultExpr,
               Join(currentChild,
                 Project(query.output :+ alwaysTrueExpr, query),
                 LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
@@ -576,7 +576,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
                 (IsNull(alwaysTrueRef), resultWithZeroTups.get),
                 (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
                 aggValRef),
-              origOutput.name)(exprId = newExprId)
+              origOutput.name)()
+
+            subqueryAttrMapping += ((origOutput, caseExpr.toAttribute))
 
             Project(
               currentChild.output :+ caseExpr,
@@ -587,6 +589,22 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
         }
       }
     }
+    (newChild, AttributeMap(subqueryAttrMapping))
+  }
+
+  private def updateAttrs[E <: Expression](
+      exprs: Seq[E],
+      attrMap: AttributeMap[Attribute]): Seq[E] = {
+    if (attrMap.nonEmpty) {
+      val newExprs = exprs.map { _.transform {
+        case a: AttributeReference if attrMap.contains(a) =>
+          val exprId = attrMap.getOrElse(a, a).exprId
+          a.withExprId(exprId)
+      }}
+      newExprs.asInstanceOf[Seq[E]]
+    } else {
+      exprs
+    }
   }
 
   /**
@@ -595,36 +613,42 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
    */
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
     case a @ Aggregate(grouping, expressions, child) =>
-      val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
-      val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
+      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
       if (subqueries.nonEmpty) {
         // We currently only allow correlated subqueries in an aggregate if they are part of the
         // grouping expressions. As a result we need to replace all the scalar subqueries in the
         // grouping expressions by their result.
         val newGrouping = grouping.map { e =>
-          subqueries.find(_._1.semanticEquals(e)).map(_._1.plan.output.head).getOrElse(e)
+          subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
         }
-        val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
+        val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
+        val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping)
+        val newAgg = Aggregate(newGrouping, newExprs, newChild)
         val attrMapping = a.output.zip(newAgg.output)
         newAgg -> attrMapping
       } else {
         a -> Nil
       }
     case p @ Project(expressions, child) =>
-      val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
-      val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
+      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
       if (subqueries.nonEmpty) {
-        val newProj = Project(newExpressions, constructLeftJoins(child, subqueries))
+        val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
+        val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping)
+        val newProj = Project(newExprs, newChild)
         val attrMapping = p.output.zip(newProj.output)
         newProj -> attrMapping
       } else {
         p -> Nil
       }
     case f @ Filter(condition, child) =>
-      val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
-      val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
+      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val rewriteCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
       if (subqueries.nonEmpty) {
-        val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
+        val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
+        val newCondition = updateAttrs(Seq(rewriteCondition), subqueryAttrMapping).head
+        val newProj = Project(f.output, Filter(newCondition, newChild))
         val attrMapping = f.output.zip(newProj.output)
         newProj -> attrMapping
       } else {
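
Note on PATCH 1/4: instead of minting a fresh `ExprId` for every extracted subquery up front, the rule now substitutes each subquery's original output attribute and defers renumbering to `constructLeftJoins`, which returns an `AttributeMap` from each original output to the attribute that actually carries the result in the rebuilt child. `updateAttrs` then rewrites the parent's expressions with that mapping, and `transformUpWithNewOutput` propagates the new output upward. The following is a minimal, self-contained sketch of this collect/rebuild/remap pattern; `Expr`, `Attr`, `Subquery`, `extract`, `constructJoins`, and the id values are hypothetical stand-ins, not Catalyst classes.

  import scala.collection.mutable.ArrayBuffer

  object RemapSketch {
    sealed trait Expr
    case class Attr(name: String, exprId: Long) extends Expr
    case class Subquery(output: Attr) extends Expr

    // Phase 1: replace each subquery with its *original* output attribute
    // and remember the subquery; no fresh exprId is allocated here.
    def extract(e: Expr, subqueries: ArrayBuffer[Subquery]): Expr = e match {
      case s: Subquery => subqueries += s; s.output
      case other => other
    }

    // Phase 2: while rebuilding the child plan, record old -> new attribute
    // for each subquery result (the real rule collects an AttributeMap).
    def constructJoins(subqueries: Seq[Subquery], firstId: Long): Map[Attr, Attr] =
      subqueries.zipWithIndex.map { case (s, i) =>
        s.output -> s.output.copy(exprId = firstId + i)
      }.toMap

    // Phase 3: rewrite the parent's expressions with the recorded mapping.
    def updateAttrs(exprs: Seq[Expr], attrMap: Map[Attr, Attr]): Seq[Expr] =
      exprs.map {
        case a: Attr => attrMap.getOrElse(a, a)
        case other => other
      }

    def main(args: Array[String]): Unit = {
      val subqueries = ArrayBuffer.empty[Subquery]
      val parentExprs = Seq[Expr](Subquery(Attr("x", 1L))).map(extract(_, subqueries))
      val mapping = constructJoins(subqueries.toSeq, firstId = 100L)
      println(updateAttrs(parentExprs, mapping)) // List(Attr(x,100))
    }
  }

Running `main` prints `List(Attr(x,100))`: the parent expression ends up referring to the attribute produced during the rebuild, not the one captured during extraction.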
From 405be2dc38f87927c0b547be20c865016f0b9bc6 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Wed, 7 Oct 2020 08:55:57 +0900
Subject: [PATCH 2/4] Fix

---
 .../org/apache/spark/sql/catalyst/optimizer/subquery.scala | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 837eab84e0881..7ba39d84bb65f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -597,9 +597,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
       attrMap: AttributeMap[Attribute]): Seq[E] = {
     if (attrMap.nonEmpty) {
       val newExprs = exprs.map { _.transform {
-        case a: AttributeReference if attrMap.contains(a) =>
-          val exprId = attrMap.getOrElse(a, a).exprId
-          a.withExprId(exprId)
+        case a: AttributeReference if attrMap.contains(a) => attrMap(a)
       }}
       newExprs.asInstanceOf[Seq[E]]
     } else {
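
Note on PATCH 2/4: once the `attrMap.contains(a)` guard holds, fetching the mapped attribute once and substituting it wholesale is equivalent to copying only its `exprId` onto the old reference, saves the second lookup, and also carries over anything else the new attribute may differ in (in Catalyst, e.g. nullability). A tiny sketch of the equivalence, with `Attr` a hypothetical stand-in whose only varying field is the exprId:

  object SubstituteSketch {
    case class Attr(name: String, exprId: Long) {
      def withExprId(id: Long): Attr = copy(exprId = id)
    }

    def main(args: Array[String]): Unit = {
      val a = Attr("x", 1L)
      val attrMap = Map(a -> Attr("x", 100L))
      // PATCH 1/4 body: keep `a`, overwrite only its exprId.
      val before = a.withExprId(attrMap.getOrElse(a, a).exprId)
      // PATCH 2/4 body: substitute the mapped attribute directly.
      val after = attrMap(a)
      assert(before == after) // identical here, since only the exprId differs
      println(after) // Attr(x,100)
    }
  }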
From 47ea616585adf31b8513302e5baf9c7eb2f88d26 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Wed, 7 Oct 2020 14:09:47 +0900
Subject: [PATCH 3/4] Fix

---
 .../org/apache/spark/sql/catalyst/optimizer/subquery.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 7ba39d84bb65f..b3c9a19a68714 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -589,7 +589,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
         }
       }
     }
-    (newChild, AttributeMap(subqueryAttrMapping))
+    (newChild, AttributeMap(subqueryAttrMapping.toSeq))
   }
 
   private def updateAttrs[E <: Expression](

From cbd0c5cf72100313d076e6cb9e0e827d96fcc645 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Wed, 7 Oct 2020 16:49:26 +0900
Subject: [PATCH 4/4] Fix

---
 .../org/apache/spark/sql/catalyst/optimizer/subquery.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index b3c9a19a68714..f184253ef0595 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -597,7 +597,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
       attrMap: AttributeMap[Attribute]): Seq[E] = {
     if (attrMap.nonEmpty) {
       val newExprs = exprs.map { _.transform {
-        case a: AttributeReference if attrMap.contains(a) => attrMap(a)
+        case a: AttributeReference => attrMap.getOrElse(a, a)
       }}
       newExprs.asInstanceOf[Seq[E]]
     } else {
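
Note on PATCH 3/4 and 4/4: the `.toSeq` in PATCH 3/4 is a Scala 2.13 cross-build fix. Under 2.13, `scala.Seq` aliases `immutable.Seq`, so a `mutable.ArrayBuffer` no longer satisfies a `Seq[...]` parameter without an explicit conversion, while on 2.12 it still does. PATCH 4/4 then drops the pattern guard in favor of a single `getOrElse` lookup, which handles mapped and unmapped references in one total case instead of a `contains` check followed by `apply`. A minimal illustration of the 2.13 point, where `build` is a hypothetical stand-in for a `Seq`-taking constructor such as `AttributeMap.apply`:

  import scala.collection.mutable.ArrayBuffer

  object ToSeqSketch {
    // Hypothetical constructor that requires an (immutable, on 2.13) Seq.
    def build[A](kvs: Seq[(A, A)]): Map[A, A] = kvs.toMap

    def main(args: Array[String]): Unit = {
      val buf = ArrayBuffer((1, 2), (3, 4))
      // build(buf) compiles on 2.12, where scala.Seq = collection.Seq,
      // but fails on 2.13, where scala.Seq = immutable.Seq.
      println(build(buf.toSeq)) // the explicit conversion works on both
    }
  }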