diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0e81f48fc7ebb..7c6d0fcd9c8c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -123,6 +123,127 @@ object AnalysisContext {
   }
 }
 
+object Analyzer {
+
+  /**
+   * Rewrites a given `plan` recursively based on rewrite mappings from old plans to new ones.
+   * This method also updates all the related references in the `plan` accordingly.
+   *
+   * @param plan the plan to rewrite
+   * @param rewritePlanMap maps old plans to the new plans that replace them in `plan`.
+   * @return the rewritten plan, along with the (old, new) attribute pairs that parent plans
+   *         of the given `plan` need in order to update their own references.
+   */
+  def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan])
+    : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
+    if (plan.resolved) {
+      val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
+      val newChildren = plan.children.map { child =>
+        // Rewrite each child plan recursively until we find a plan contained in
+        // `rewritePlanMap` or reach a leaf node.
+        val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
+        attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
+          // `attrMapping` is not only used to replace the attributes of the current `plan`,
+          // but also to be propagated to the parent plans of the current `plan`. Therefore,
+          // the `oldAttr` must be part of either `plan.references` (so that it can be used to
+          // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
+          // used by those parent plans).
+          (plan.outputSet ++ plan.references).contains(oldAttr)
+        }
+        newChild
+      }
+
+      val newPlan = if (rewritePlanMap.contains(plan)) {
+        rewritePlanMap(plan).withNewChildren(newChildren)
+      } else {
+        plan.withNewChildren(newChildren)
+      }
+
+      assert(!attrMapping.groupBy(_._1.exprId)
+        .exists(_._2.map(_._2.exprId).distinct.length > 1),
+        "Found duplicate rewrite attributes")
+
+      val attributeRewrites = AttributeMap(attrMapping)
+      // Using attrMapping from the children plans to rewrite their parent node.
+      // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
+      val p = newPlan.transformExpressions {
+        case a: Attribute =>
+          updateAttr(a, attributeRewrites)
+        case s: SubqueryExpression =>
+          s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
+      }
+      attrMapping ++= plan.output.zip(p.output)
+        .filter { case (a1, a2) => a1.exprId != a2.exprId }
+      p -> attrMapping
+    } else {
+      // Just pass unresolved nodes through.
+      plan.mapChildren {
+        rewritePlan(_, rewritePlanMap)._1
+      } -> Nil
+    }
+  }
+
+  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
+    val exprId = attrMap.getOrElse(attr, attr).exprId
+    attr.withExprId(exprId)
+  }
+
+  /**
+   * The outer plan may still reference old attributes; this method updates those
+   * outer references to refer to the new attributes.
+   *
+   * For example (SQL):
+   * {{{
+   *   SELECT * FROM t1
+   *   INTERSECT
+   *   SELECT * FROM t1
+   *   WHERE EXISTS (SELECT 1
+   *                 FROM t2
+   *                 WHERE t1.c1 = t2.c1)
+   * }}}
+   * Plan before the resolveReference rule.
+ * 'Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- 'Project [*] + * +- Filter exists#257 [c1#245] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#245) = c1#251) + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#245,c2#246] parquet + * Plan after the resolveReference rule. + * Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- Project [c1#259, c2#260] + * +- Filter exists#257 [c1#259] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#259) = c1#251) => Updated + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are rewritten. + */ + private def updateOuterReferencesInSubquery( + plan: LogicalPlan, + attrMap: AttributeMap[Attribute]): LogicalPlan = { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + plan transformDown { case currentFragment => + currentFragment transformExpressions { + case OuterReference(a: Attribute) => + OuterReference(updateAttr(a, attrMap)) + case s: SubqueryExpression => + s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attrMap)) + } + } + } + } +} + /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. @@ -1251,109 +1372,7 @@ class Analyzer( if (conflictPlans.isEmpty) { right } else { - rewritePlan(right, conflictPlans.toMap)._1 - } - } - - private def rewritePlan(plan: LogicalPlan, conflictPlanMap: Map[LogicalPlan, LogicalPlan]) - : (LogicalPlan, Seq[(Attribute, Attribute)]) = { - if (conflictPlanMap.contains(plan)) { - // If the plan is the one that conflict the with left one, we'd - // just replace it with the new plan and collect the rewrite - // attributes for the parent node. - val newRelation = conflictPlanMap(plan) - newRelation -> plan.output.zip(newRelation.output) - } else { - val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() - val newPlan = plan.mapChildren { child => - // If not, we'd rewrite child plan recursively until we find the - // conflict node or reach the leaf node. - val (newChild, childAttrMapping) = rewritePlan(child, conflictPlanMap) - attrMapping ++= childAttrMapping.filter { case (oldAttr, _) => - // `attrMapping` is not only used to replace the attributes of the current `plan`, - // but also to be propagated to the parent plans of the current `plan`. Therefore, - // the `oldAttr` must be part of either `plan.references` (so that it can be used to - // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be - // used by those parent plans). - (plan.outputSet ++ plan.references).contains(oldAttr) - } - newChild - } - - if (attrMapping.isEmpty) { - newPlan -> attrMapping.toSeq - } else { - assert(!attrMapping.groupBy(_._1.exprId) - .exists(_._2.map(_._2.exprId).distinct.length > 1), - "Found duplicate rewrite attributes") - val attributeRewrites = AttributeMap(attrMapping.toSeq) - // Using attrMapping from the children plans to rewrite their parent node. - // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. 
-      newPlan.transformExpressions {
-        case a: Attribute =>
-          dedupAttr(a, attributeRewrites)
-        case s: SubqueryExpression =>
-          s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
-      } -> attrMapping.toSeq
-    }
-  }
-  }
-
-  private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
-    val exprId = attrMap.getOrElse(attr, attr).exprId
-    attr.withExprId(exprId)
-  }
-
-  /**
-   * The outer plan may have been de-duplicated and the function below updates the
-   * outer references to refer to the de-duplicated attributes.
-   *
-   * For example (SQL):
-   * {{{
-   *   SELECT * FROM t1
-   *   INTERSECT
-   *   SELECT * FROM t1
-   *   WHERE EXISTS (SELECT 1
-   *                 FROM t2
-   *                 WHERE t1.c1 = t2.c1)
-   * }}}
-   * Plan before resolveReference rule.
-   *    'Intersect
-   *    :- Project [c1#245, c2#246]
-   *    :  +- SubqueryAlias t1
-   *    :     +- Relation[c1#245,c2#246] parquet
-   *    +- 'Project [*]
-   *       +- Filter exists#257 [c1#245]
-   *          :  +- Project [1 AS 1#258]
-   *          :     +- Filter (outer(c1#245) = c1#251)
-   *          :        +- SubqueryAlias t2
-   *          :           +- Relation[c1#251,c2#252] parquet
-   *          +- SubqueryAlias t1
-   *             +- Relation[c1#245,c2#246] parquet
-   * Plan after the resolveReference rule.
-   *    Intersect
-   *    :- Project [c1#245, c2#246]
-   *    :  +- SubqueryAlias t1
-   *    :     +- Relation[c1#245,c2#246] parquet
-   *    +- Project [c1#259, c2#260]
-   *       +- Filter exists#257 [c1#259]
-   *          :  +- Project [1 AS 1#258]
-   *          :     +- Filter (outer(c1#259) = c1#251) => Updated
-   *          :        +- SubqueryAlias t2
-   *          :           +- Relation[c1#251,c2#252] parquet
-   *          +- SubqueryAlias t1
-   *             +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated.
-   */
-  private def dedupOuterReferencesInSubquery(
-      plan: LogicalPlan,
-      attrMap: AttributeMap[Attribute]): LogicalPlan = {
-    plan transformDown { case currentFragment =>
-      currentFragment transformExpressions {
-        case OuterReference(a: Attribute) =>
-          OuterReference(dedupAttr(a, attrMap))
-        case s: SubqueryExpression =>
-          s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap))
-      }
+      Analyzer.rewritePlan(right, conflictPlans.toMap)._1
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 604a082be4e55..861eddedc0e1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -326,29 +326,53 @@ object TypeCoercion {
    *
    * This rule is only applied to Union/Except/Intersect
    */
-  object WidenSetOperationTypes extends Rule[LogicalPlan] {
-
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
-      case s @ Except(left, right, isAll) if s.childrenResolved &&
-          left.output.length == right.output.length && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-        assert(newChildren.length == 2)
-        Except(newChildren.head, newChildren.last, isAll)
-
-      case s @ Intersect(left, right, isAll) if s.childrenResolved &&
-          left.output.length == right.output.length && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-        assert(newChildren.length == 2)
-        Intersect(newChildren.head, newChildren.last, isAll)
-
-    case s: Union if s.childrenResolved && !s.byName &&
+  object WidenSetOperationTypes extends TypeCoercionRule {
+
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
+      val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]()
+      val newPlan = plan resolveOperatorsUp {
+        case s @ Except(left, right, isAll) if s.childrenResolved &&
+            left.output.length == right.output.length && !s.resolved =>
+          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.nonEmpty) {
+            rewritePlanMap ++= newChildren
+            Except(newChildren.head._1, newChildren.last._1, isAll)
+          } else {
+            s
+          }
+
+        case s @ Intersect(left, right, isAll) if s.childrenResolved &&
+            left.output.length == right.output.length && !s.resolved =>
+          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.nonEmpty) {
+            rewritePlanMap ++= newChildren
+            Intersect(newChildren.head._1, newChildren.last._1, isAll)
+          } else {
+            s
+          }
+
+        case s: Union if s.childrenResolved && !s.byName &&
           s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
-        s.copy(children = newChildren)
+          val newChildren = buildNewChildrenWithWiderTypes(s.children)
+          if (newChildren.nonEmpty) {
+            rewritePlanMap ++= newChildren
+            s.copy(children = newChildren.map(_._1))
+          } else {
+            s
+          }
+      }
+
+      if (rewritePlanMap.nonEmpty) {
+        assert(!plan.fastEquals(newPlan))
+        Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
+      } else {
+        plan
+      }
     }
 
     /** Build new children with the widest types for each attribute among all the children */
-    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
+    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan])
+      : Seq[(LogicalPlan, LogicalPlan)] = {
       require(children.forall(_.output.length == children.head.output.length))
 
       // Get a sequence of data types, each of which is the widest type of this specific attribute
@@ -360,8 +384,7 @@ object TypeCoercion {
         // Add an extra Project if the targetTypes are different from the original types.
         children.map(widenTypes(_, targetTypes))
       } else {
-        // Unable to find a target type to widen, then just return the original set.
-        children
+        Nil
       }
     }
 
@@ -385,12 +408,16 @@ object TypeCoercion {
     }
 
    /** Given a plan, add an extra project on top to widen some columns' data types. */
-    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
+    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType])
+      : (LogicalPlan, LogicalPlan) = {
       val casted = plan.output.zip(targetTypes).map {
-        case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
-        case (e, _) => e
-      }
-      Project(casted, plan)
+        case (e, dt) if e.dataType != dt =>
+          val alias = Alias(Cast(e, dt), e.name)(exprId = e.exprId)
+          alias -> alias.newInstance()
+        case (e, _) =>
+          e -> e
+      }.unzip
+      Project(casted._1, plan) -> Project(casted._2, plan)
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 1ea1ddb8bbd08..1af562fd1a061 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -21,13 +21,12 @@ import java.sql.Timestamp
 
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
 
 class TypeCoercionSuite extends AnalysisTest {
   import TypeCoercionSuite._
@@ -1417,6 +1416,20 @@ class TypeCoercionSuite extends AnalysisTest {
     }
   }
 
+  test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") {
+    val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
+    val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
+    val p1 = t1.select(t1.output.head)
+    val p2 = t2.select(t2.output.head)
+    val union = p1.union(p2)
+    val wp1 = widenSetOperationTypes(union.select(p1.output.head))
+    assert(wp1.isInstanceOf[Project])
+    assert(wp1.missingInput.isEmpty)
+    val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
+    assert(wp2.isInstanceOf[Aggregate])
+    assert(wp2.missingInput.isEmpty)
+  }
+
   /**
    * There are rules that need to not fire before child expressions get resolved.
   * We use this test to make sure those rules do not fire early.
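Reviewer note, not part of the patch: the suite test above is the crux of the fix. Before this
change, widenTypes wrapped each child of a set operation in Project(Alias(Cast(e, dt), e.name)(),
child), and Alias(...)() mints a fresh exprId, so a parent operator that still referred to the
child's original output attribute was left with a dangling reference. The patched widenTypes keeps
the old exprId in the Project that replaces the child, pairs it with a fresh copy, and lets
Analyzer.rewritePlan remap the parent's references. A minimal sketch of the scenario, mirroring
the new test (the object name WidenSetOpSketch is illustrative; calling the rule directly stands
in for the suite's widenSetOperationTypes helper):

  import org.apache.spark.sql.catalyst.analysis.TypeCoercion
  import org.apache.spark.sql.catalyst.dsl.plans._
  import org.apache.spark.sql.catalyst.expressions.AttributeReference
  import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
  import org.apache.spark.sql.types.DecimalType

  object WidenSetOpSketch {
    def main(args: Array[String]): Unit = {
      val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
      val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
      // Union whose left side must be widened from decimal(10,0) to decimal(11,0).
      val union = t1.select(t1.output.head).union(t2.select(t2.output.head))
      // A parent that still refers to the left child's original output attribute.
      val parent = union.select(t1.output.head)
      // With the old rule, this Project's reference dangled after widening;
      // with this patch, the rewritten plan keeps every reference bound.
      val widened = TypeCoercion.WidenSetOperationTypes(parent)
      assert(widened.missingInput.isEmpty)
    }
  }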
diff --git a/sql/core/src/test/resources/sql-tests/inputs/except.sql b/sql/core/src/test/resources/sql-tests/inputs/except.sql
index 1d579e65f3473..ffdf1f4f3d24d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/except.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/except.sql
@@ -55,3 +55,22 @@ FROM   t1
 WHERE  t1.v >= (SELECT   min(t2.v)
                 FROM     t2
                 WHERE    t2.k = t1.k);
+
+-- SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes
+CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v);
+SELECT t.v FROM (
+  SELECT v FROM t3
+  EXCEPT
+  SELECT v + v AS v FROM t3
+) t;
+
+SELECT SUM(t.v) FROM (
+  SELECT v FROM t3
+  EXCEPT
+  SELECT v + v AS v FROM t3
+) t;
+
+-- Clean-up
+DROP VIEW IF EXISTS t1;
+DROP VIEW IF EXISTS t2;
+DROP VIEW IF EXISTS t3;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql
index b0b2244048caa..077caa5dd44a0 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql
@@ -155,6 +155,21 @@ SELECT * FROM tab2;
 -- Restore the property
 SET spark.sql.legacy.setopsPrecedence.enabled = false;
 
+-- SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes
+CREATE OR REPLACE TEMPORARY VIEW tab3 AS VALUES (decimal(1)), (decimal(2)) tbl3(v);
+SELECT t.v FROM (
+  SELECT v FROM tab3
+  INTERSECT
+  SELECT v + v AS v FROM tab3
+) t;
+
+SELECT SUM(t.v) FROM (
+  SELECT v FROM tab3
+  INTERSECT
+  SELECT v + v AS v FROM tab3
+) t;
+
 -- Clean-up
 DROP VIEW IF EXISTS tab1;
 DROP VIEW IF EXISTS tab2;
+DROP VIEW IF EXISTS tab3;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql
index 6da1b9b49b226..8a5b6c50fc1e3 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/union.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql
@@ -45,10 +45,24 @@ SELECT array(1, 2), 'str'
 UNION ALL
 SELECT array(1, 2, 3, NULL), 1;
 
+-- SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes
+CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v);
+SELECT t.v FROM (
+  SELECT v FROM t3
+  UNION ALL
+  SELECT v + v AS v FROM t3
+) t;
+
+SELECT SUM(t.v) FROM (
+  SELECT v FROM t3
+  UNION
+  SELECT v + v AS v FROM t3
+) t;
 
 -- Clean-up
 DROP VIEW IF EXISTS t1;
 DROP VIEW IF EXISTS t2;
+DROP VIEW IF EXISTS t3;
 DROP VIEW IF EXISTS p1;
 DROP VIEW IF EXISTS p2;
 DROP VIEW IF EXISTS p3;
diff --git a/sql/core/src/test/resources/sql-tests/results/except.sql.out b/sql/core/src/test/resources/sql-tests/results/except.sql.out
index 62d695219d01d..061b122eac7cf 100644
--- a/sql/core/src/test/resources/sql-tests/results/except.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/except.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 9
+-- Number of queries: 15
 
 
 -- !query
@@ -103,3 +103,59 @@ WHERE  t1.v >= (SELECT   min(t2.v)
 struct<k:string>
 -- !query output
 two
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT t.v FROM (
+  SELECT v FROM t3
+  EXCEPT
+  SELECT v + v AS v FROM t3
+) t
+-- !query schema
+struct<v:decimal(11,0)>
+-- !query output
+1
+
+
+-- !query
+SELECT SUM(t.v) FROM (
+  SELECT v FROM t3
+  EXCEPT
+  SELECT v + v AS v FROM t3
+) t
+-- !query schema
+struct<sum(v):decimal(21,0)>
+-- !query output
+1
+
+
+-- !query
+DROP VIEW IF EXISTS t1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IF EXISTS t2
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IF EXISTS t3
+-- !query schema
+struct<>
+-- !query output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out
index 4762082dc3be2..b99f63393cc4d 100644
--- a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 22
+-- Number of queries: 26
 
 
 -- !query
@@ -291,6 +291,38 @@ struct<key:string,value:string>
 spark.sql.legacy.setopsPrecedence.enabled	false
 
 
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW tab3 AS VALUES (decimal(1)), (decimal(2)) tbl3(v)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT t.v FROM (
+  SELECT v FROM tab3
+  INTERSECT
+  SELECT v + v AS v FROM tab3
+) t
+-- !query schema
+struct<v:decimal(11,0)>
+-- !query output
+2
+
+
+-- !query
+SELECT SUM(t.v) FROM (
+  SELECT v FROM tab3
+  INTERSECT
+  SELECT v + v AS v FROM tab3
+) t
+-- !query schema
+struct<sum(v):decimal(21,0)>
+-- !query output
+2
+
+
 -- !query
 DROP VIEW IF EXISTS tab1
 -- !query schema
@@ -305,3 +337,11 @@ DROP VIEW IF EXISTS tab2
 struct<>
 -- !query output
 
+
+
+-- !query
+DROP VIEW IF EXISTS tab3
+-- !query schema
+struct<>
+-- !query output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out
index 44002406836a4..ce3c761bc5d2d 100644
--- a/sql/core/src/test/resources/sql-tests/results/union.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 16
+-- Number of queries: 20
 
 
 -- !query
@@ -126,6 +126,39 @@ struct<array(1, 2):array<int>,str:string>
 [1,2]	str
 
 
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT t.v FROM (
+  SELECT v FROM t3
+  UNION ALL
+  SELECT v + v AS v FROM t3
+) t
+-- !query schema
+struct<v:decimal(11,0)>
+-- !query output
+1
+2
+
+
+-- !query
+SELECT SUM(t.v) FROM (
+  SELECT v FROM t3
+  UNION
+  SELECT v + v AS v FROM t3
+) t
+-- !query schema
+struct<sum(v):decimal(21,0)>
+-- !query output
+3
+
+
 -- !query
 DROP VIEW IF EXISTS t1
 -- !query schema
@@ -142,6 +175,14 @@ struct<>
 
 
 
+-- !query
+DROP VIEW IF EXISTS t3
+-- !query schema
+struct<>
+-- !query output
+
+
+
 -- !query
 DROP VIEW IF EXISTS p1
 -- !query schema
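Reviewer note, not part of the patch: a minimal sketch of the contract of the newly extracted
Analyzer.rewritePlan helper, usable from any rule that swaps sub-plans, as WidenSetOperationTypes
now does (the object name RewritePlanSketch and the relation are illustrative):

  import org.apache.spark.sql.catalyst.analysis.Analyzer
  import org.apache.spark.sql.catalyst.dsl.plans._
  import org.apache.spark.sql.catalyst.expressions.AttributeReference
  import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
  import org.apache.spark.sql.types.IntegerType

  object RewritePlanSketch {
    def main(args: Array[String]): Unit = {
      val oldRel = LocalRelation(AttributeReference("a", IntegerType)())
      val newRel = oldRel.newInstance() // same schema, fresh exprIds
      val plan: LogicalPlan = oldRel.select(oldRel.output.head)
      // rewritePlan walks down to the mapped node, swaps it in, and then
      // rewrites the parent Project's reference to `a` with newRel's exprId.
      val (rewritten, attrMapping) =
        Analyzer.rewritePlan(plan, Map[LogicalPlan, LogicalPlan](oldRel -> newRel))
      assert(rewritten.missingInput.isEmpty)
      // attrMapping pairs oldRel's output attribute with newRel's, so a caller
      // can fix up references that live above `plan` in a larger tree.
    }
  }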