diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index c0ba3598e4ba1..976a5d385d874 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -69,13 +69,13 @@ object CTESubstitution extends Rule[LogicalPlan] { if (cteDefs.isEmpty) { substituted } else if (substituted eq lastSubstituted.get) { - WithCTE(substituted, cteDefs.toSeq) + WithCTE(substituted, cteDefs.sortBy(_.id).toSeq) } else { var done = false substituted.resolveOperatorsWithPruning(_ => !done) { case p if p eq lastSubstituted.get => done = true - WithCTE(p, cteDefs.toSeq) + WithCTE(p, cteDefs.sortBy(_.id).toSeq) } } } @@ -203,6 +203,7 @@ object CTESubstitution extends Rule[LogicalPlan] { cteDefs: mutable.ArrayBuffer[CTERelationDef]): Seq[(String, CTERelationDef)] = { val resolvedCTERelations = new mutable.ArrayBuffer[(String, CTERelationDef)](relations.size) for ((name, relation) <- relations) { + val lastCTEDefCount = cteDefs.length val innerCTEResolved = if (isLegacy) { // In legacy mode, outer CTE relations take precedence. Here we don't resolve the inner // `With` nodes, later we will substitute `UnresolvedRelation`s with outer CTE relations. @@ -211,8 +212,33 @@ object CTESubstitution extends Rule[LogicalPlan] { } else { // A CTE definition might contain an inner CTE that has a higher priority, so traverse and // substitute CTE defined in `relation` first. + // NOTE: we must call `traverseAndSubstituteCTE` before `substituteCTE`, as the relations + // in the inner CTE have higher priority over the relations in the outer CTE when resolving + // inner CTE relations. For example: + // WITH t1 AS (SELECT 1) + // t2 AS ( + // WITH t1 AS (SELECT 2) + // WITH t3 AS (SELECT * FROM t1) + // ) + // t3 should resolve the t1 to `SELECT 2` instead of `SELECT 1`. traverseAndSubstituteCTE(relation, isCommand, cteDefs)._1 } + + if (cteDefs.length > lastCTEDefCount) { + // We have added more CTE relations to the `cteDefs` from the inner CTE, and these relations + // should also be substituted with `resolvedCTERelations` as inner CTE relation can refer to + // outer CTE relation. 
For example: + // WITH t1 AS (SELECT 1) + // t2 AS ( + // WITH t3 AS (SELECT * FROM t1) + // ) + for (i <- lastCTEDefCount until cteDefs.length) { + val substituted = + substituteCTE(cteDefs(i).child, isLegacy || isCommand, resolvedCTERelations.toSeq) + cteDefs(i) = cteDefs(i).copy(child = substituted) + } + } + // CTE definition can reference a previous one val substituted = substituteCTE(innerCTEResolved, isLegacy || isCommand, resolvedCTERelations.toSeq) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7928aad1e6421..936dadc78c13e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, PercentileCont} -import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery} +import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery, InlineCTE} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -93,8 +93,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { def checkAnalysis(plan: LogicalPlan): Unit = { // We transform up and order the rules so as to catch the first possible failure instead - // of the result of cascading resolution failures. - plan.foreachUp { + // of the result of cascading resolution failures. Inline all CTEs in the plan to help check + // query plan structures in subqueries. + val inlineCTE = InlineCTE(alwaysInline = true) + inlineCTE(plan).foreachUp { case p if p.analyzed => // Skip already analyzed sub-plans diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala index 61577b1d21ea4..a740b92933fa4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala @@ -28,26 +28,37 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION} /** * Inlines CTE definitions into corresponding references if either of the conditions satisfies: - * 1. The CTE definition does not contain any non-deterministic expressions. If this CTE - * definition references another CTE definition that has non-deterministic expressions, it - * is still OK to inline the current CTE definition. + * 1. The CTE definition does not contain any non-deterministic expressions or contains attribute + * references to an outer query. If this CTE definition references another CTE definition that + * has non-deterministic expressions, it is still OK to inline the current CTE definition. * 2. The CTE definition is only referenced once throughout the main query and all the subqueries. * - * In addition, due to the complexity of correlated subqueries, all CTE references in correlated - * subqueries are inlined regardless of the conditions above. 
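A minimal sketch of the nested-CTE resolution behavior described by the new CTESubstitution comments above, assuming a running SparkSession named `spark`; the CTE names and literal values are illustrative only. An inner CTE can refer to an outer definition, which is why the inner definitions collected while resolving `t2` must themselves be substituted with the already-resolved outer relations; when an inner definition reuses an outer name, the non-legacy precedence described in the NOTE lets the inner definition win (subject to `spark.sql.legacy.ctePrecedencePolicy`).

    // t3 is defined inside t2's WITH clause but references the outer t1,
    // so the whole query should return the outer t1's value.
    val outerRef = spark.sql(
      """WITH t1 AS (SELECT 1 AS c),
        |     t2 AS (
        |       WITH t3 AS (SELECT * FROM t1)
        |       SELECT * FROM t3
        |     )
        |SELECT * FROM t2""".stripMargin)
    // outerRef.collect() is expected to yield a single row containing 1.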
+ * CTE definitions that appear in subqueries and are not inlined will be pulled up to the main + * query level. + * + * @param alwaysInline if true, inline all CTEs in the query plan. */ -object InlineCTE extends Rule[LogicalPlan] { +case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) { val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int)] buildCTEMap(plan, cteMap) - inlineCTE(plan, cteMap, forceInline = false) + val notInlined = mutable.ArrayBuffer.empty[CTERelationDef] + val inlined = inlineCTE(plan, cteMap, notInlined) + // CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add + // WithCTE as top node here. + if (notInlined.isEmpty) { + inlined + } else { + WithCTE(inlined, notInlined.toSeq) + } } else { plan } } - private def shouldInline(cteDef: CTERelationDef, refCount: Int): Boolean = { + private def shouldInline(cteDef: CTERelationDef, refCount: Int): Boolean = alwaysInline || { // We do not need to check enclosed `CTERelationRef`s for `deterministic` or `OuterReference`, // because: // 1) It is fine to inline a CTE if it references another CTE that is non-deterministic; @@ -93,25 +104,24 @@ object InlineCTE extends Rule[LogicalPlan] { private def inlineCTE( plan: LogicalPlan, cteMap: mutable.HashMap[Long, (CTERelationDef, Int)], - forceInline: Boolean): LogicalPlan = { - val (stripped, notInlined) = plan match { + notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = { + plan match { case WithCTE(child, cteDefs) => - val notInlined = mutable.ArrayBuffer.empty[CTERelationDef] cteDefs.foreach { cteDef => val (cte, refCount) = cteMap(cteDef.id) if (refCount > 0) { - val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, forceInline)) + val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined)) cteMap.update(cteDef.id, (inlined, refCount)) - if (!forceInline && !shouldInline(inlined, refCount)) { + if (!shouldInline(inlined, refCount)) { notInlined.append(inlined) } } } - (inlineCTE(child, cteMap, forceInline), notInlined.toSeq) + inlineCTE(child, cteMap, notInlined) case ref: CTERelationRef => val (cteDef, refCount) = cteMap(ref.cteId) - val newRef = if (forceInline || shouldInline(cteDef, refCount)) { + if (shouldInline(cteDef, refCount)) { if (ref.outputSet == cteDef.outputSet) { cteDef.child } else { @@ -125,24 +135,16 @@ object InlineCTE extends Rule[LogicalPlan] { } else { ref } - (newRef, Seq.empty) case _ if plan.containsPattern(CTE) => - val newPlan = plan - .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, forceInline))) + plan + .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, notInlined))) .transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) { case e: SubqueryExpression => - e.withNewPlan(inlineCTE(e.plan, cteMap, forceInline = e.isCorrelated)) + e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined)) } - (newPlan, Seq.empty) - case _ => (plan, Seq.empty) - } - - if (notInlined.isEmpty) { - stripped - } else { - WithCTE(stripped, notInlined) + case _ => plan } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bb788336c6d77..cf2da22f6c22b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -128,7 +128,8 @@ abstract class Optimizer(catalogManager: CatalogManager) OptimizeUpdateFields, SimplifyExtractValueOps, OptimizeCsvJsonExprs, - CombineConcats) ++ + CombineConcats, + PushdownPredicatesAndPruneColumnsForCTEDef) ++ extendedOperatorOptimizationRules val operatorOptimizationBatch: Seq[Batch] = { @@ -147,22 +148,7 @@ abstract class Optimizer(catalogManager: CatalogManager) } val batches = ( - // Technically some of the rules in Finish Analysis are not optimizer rules and belong more - // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime). - // However, because we also use the analyzer to canonicalized queries (for view definition), - // we do not eliminate subqueries or compute current time in the analyzer. - Batch("Finish Analysis", Once, - EliminateResolvedHint, - EliminateSubqueryAliases, - EliminateView, - InlineCTE, - ReplaceExpressions, - RewriteNonCorrelatedExists, - PullOutGroupingExpressions, - ComputeCurrentTime, - ReplaceCurrentLike(catalogManager), - SpecialDatetimeValues, - RewriteAsOfJoin) :: + Batch("Finish Analysis", Once, FinishAnalysis) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// @@ -172,6 +158,8 @@ abstract class Optimizer(catalogManager: CatalogManager) // extra operators between two adjacent Union operators. // - Call CombineUnions again in Batch("Operator Optimizations"), // since the other rules might make two separate Unions operators adjacent. + Batch("Inline CTE", Once, + InlineCTE()) :: Batch("Union", Once, RemoveNoopOperators, CombineUnions, @@ -208,6 +196,7 @@ abstract class Optimizer(catalogManager: CatalogManager) RemoveLiteralFromGroupExpressions, RemoveRepetitionFromGroupExpressions) :: Nil ++ operatorOptimizationBatch) :+ + Batch("Clean Up Temporary CTE Info", Once, CleanUpTempCTEInfo) :+ // This batch rewrites plans after the operator optimization and // before any batches that depend on stats. Batch("Pre CBO Rules", Once, preCBORules: _*) :+ @@ -266,14 +255,7 @@ abstract class Optimizer(catalogManager: CatalogManager) * (defaultBatches - (excludedRules - nonExcludableRules)). */ def nonExcludableRules: Seq[String] = - EliminateDistinct.ruleName :: - EliminateResolvedHint.ruleName :: - EliminateSubqueryAliases.ruleName :: - EliminateView.ruleName :: - ReplaceExpressions.ruleName :: - ComputeCurrentTime.ruleName :: - SpecialDatetimeValues.ruleName :: - ReplaceCurrentLike(catalogManager).ruleName :: + FinishAnalysis.ruleName :: RewriteDistinctAggregates.ruleName :: ReplaceDeduplicateWithAggregate.ruleName :: ReplaceIntersectWithSemiJoin.ruleName :: @@ -287,10 +269,38 @@ abstract class Optimizer(catalogManager: CatalogManager) RewritePredicateSubquery.ruleName :: NormalizeFloatingNumbers.ruleName :: ReplaceUpdateFieldsExpression.ruleName :: - PullOutGroupingExpressions.ruleName :: - RewriteAsOfJoin.ruleName :: RewriteLateralSubquery.ruleName :: Nil + /** + * Apply finish-analysis rules for the entire plan including all subqueries. + */ + object FinishAnalysis extends Rule[LogicalPlan] { + // Technically some of the rules in Finish Analysis are not optimizer rules and belong more + // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime). 
+ // However, because we also use the analyzer to canonicalized queries (for view definition), + // we do not eliminate subqueries or compute current time in the analyzer. + private val rules = Seq( + EliminateResolvedHint, + EliminateSubqueryAliases, + EliminateView, + ReplaceExpressions, + RewriteNonCorrelatedExists, + PullOutGroupingExpressions, + ComputeCurrentTime, + ReplaceCurrentLike(catalogManager), + SpecialDatetimeValues, + RewriteAsOfJoin) + + override def apply(plan: LogicalPlan): LogicalPlan = { + rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } + .transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { + case s: SubqueryExpression => + val Subquery(newPlan, _) = apply(Subquery.fromExpression(s)) + s.withNewPlan(newPlan) + } + } + } + /** * Optimize all the subqueries inside expression. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala new file mode 100644 index 0000000000000..ab9f20edb0bb9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, Literal, Or, SubqueryExpression} +import org.apache.spark.sql.catalyst.planning.ScanOperation +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.CTE + +/** + * Infer predicates and column pruning for [[CTERelationDef]] from its reference points, and push + * the disjunctive predicates as well as the union of attributes down the CTE plan. 
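A minimal sketch of the effect this rule aims for, assuming a SparkSession `spark` and a hypothetical table `events(a, b, c, d)`; names and literals are illustrative. The CTE below contains `rand()`, so with two references it is not inlined by `InlineCTE` and this rule can apply to its definition.

    // Both references filter v on column a and project only a and b, so the rule can
    // push Filter(a = 1 OR a = 2) into v's definition and prune its output to {a, b}.
    val pushed = spark.sql(
      """WITH v AS (SELECT a, b, c, d, rand() AS r FROM events)
        |SELECT a, b FROM v WHERE a = 1
        |UNION ALL
        |SELECT a, b FROM v WHERE a = 2""".stripMargin)
    // pushed.queryExecution.optimizedPlan can be inspected to check the pushed-down
    // disjunction and the pruned CTE output.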
+ */ +object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] { + + // CTE_id - (CTE_definition, precedence, predicates_to_push_down, attributes_to_prune) + private type CTEMap = mutable.HashMap[Long, (CTERelationDef, Int, Seq[Expression], AttributeSet)] + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) { + val cteMap = new CTEMap + gatherPredicatesAndAttributes(plan, cteMap) + pushdownPredicatesAndAttributes(plan, cteMap) + } else { + plan + } + } + + private def restoreCTEDefAttrs( + input: Seq[Expression], + mapping: Map[Attribute, Expression]): Seq[Expression] = { + input.map(e => e.transform { + case a: Attribute => + mapping.keys.find(_.semanticEquals(a)).map(mapping).getOrElse(a) + }) + } + + /** + * Gather all the predicates and referenced attributes on different points of CTE references + * using pattern `ScanOperation` (which takes care of determinism) and combine those predicates + * and attributes that belong to the same CTE definition. + * For the same CTE definition, if any of its references does not have predicates, the combined + * predicate will be a TRUE literal, which means there will be no predicate push-down. + */ + private def gatherPredicatesAndAttributes(plan: LogicalPlan, cteMap: CTEMap): Unit = { + plan match { + case WithCTE(child, cteDefs) => + cteDefs.zipWithIndex.foreach { case (cteDef, precedence) => + gatherPredicatesAndAttributes(cteDef.child, cteMap) + cteMap.put(cteDef.id, (cteDef, precedence, Seq.empty, AttributeSet.empty)) + } + gatherPredicatesAndAttributes(child, cteMap) + + case ScanOperation(projects, predicates, ref: CTERelationRef) => + val (cteDef, precedence, preds, attrs) = cteMap(ref.cteId) + val attrMapping = ref.output.zip(cteDef.output).map{ case (r, d) => r -> d }.toMap + val newPredicates = if (isTruePredicate(preds)) { + preds + } else { + // Make sure we only push down predicates that do not contain forward CTE references. + val filteredPredicates = restoreCTEDefAttrs(predicates.filter(_.find { + case s: SubqueryExpression => s.plan.find { + case r: CTERelationRef => + // If the ref's ID does not exist in the map or if ref's corresponding precedence + // is bigger than that of the current CTE we are pushing predicates for, it + // indicates a forward reference and we should exclude this predicate. + !cteMap.contains(r.cteId) || cteMap(r.cteId)._2 >= precedence + case _ => false + }.nonEmpty + case _ => false + }.isEmpty), attrMapping).filter(_.references.forall(cteDef.outputSet.contains)) + if (filteredPredicates.isEmpty) { + Seq(Literal.TrueLiteral) + } else { + preds :+ filteredPredicates.reduce(And) + } + } + val newAttributes = attrs ++ + AttributeSet(restoreCTEDefAttrs(projects.flatMap(_.references), attrMapping)) ++ + AttributeSet(restoreCTEDefAttrs(predicates.flatMap(_.references), attrMapping)) + + cteMap.update(ref.cteId, (cteDef, precedence, newPredicates, newAttributes)) + plan.subqueriesAll.foreach(s => gatherPredicatesAndAttributes(s, cteMap)) + + case _ => + plan.children.foreach(c => gatherPredicatesAndAttributes(c, cteMap)) + plan.subqueries.foreach(s => gatherPredicatesAndAttributes(s, cteMap)) + } + } + + /** + * Push down the combined predicate and attribute references to each CTE definition plan. 
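Continuing the hypothetical `events` sketch above: if any reference of the same CTE carries no predicate, the predicates gathered here collapse to a TRUE literal and nothing is pushed down, while column pruning of unreferenced attributes can still apply.

    // The second reference reads v unfiltered, so no predicate is pushed into v's
    // definition; unreferenced columns such as c and d can still be pruned.
    val noPushdown = spark.sql(
      """WITH v AS (SELECT a, b, c, d, rand() AS r FROM events)
        |SELECT a, b FROM v WHERE a = 1
        |UNION ALL
        |SELECT a, b FROM v""".stripMargin)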
+ * + * In order to guarantee idempotency, we keep the predicates (if any) being pushed down by the + * last iteration of this rule in a temporary field of `CTERelationDef`, so that on the current + * iteration, we only push down predicates for a CTE def if there exists any new predicate that + * has not been pushed before. Also, since part of a new predicate might overlap with some + * existing predicate and it can be hard to extract only the non-overlapping part, we also keep + * the original CTE definition plan without any predicate push-down in that temporary field so + * that when we do a new predicate push-down, we can construct a new plan with all latest + * predicates over the original plan without having to figure out the exact predicate difference. + */ + private def pushdownPredicatesAndAttributes( + plan: LogicalPlan, + cteMap: CTEMap): LogicalPlan = plan.transformWithSubqueries { + case cteDef @ CTERelationDef(child, id, originalPlanWithPredicates) => + val (_, _, newPreds, newAttrSet) = cteMap(id) + val originalPlan = originalPlanWithPredicates.map(_._1).getOrElse(child) + val preds = originalPlanWithPredicates.map(_._2).getOrElse(Seq.empty) + if (!isTruePredicate(newPreds) && + newPreds.exists(newPred => !preds.exists(_.semanticEquals(newPred)))) { + val newCombinedPred = newPreds.reduce(Or) + val newChild = if (needsPruning(originalPlan, newAttrSet)) { + Project(newAttrSet.toSeq, originalPlan) + } else { + originalPlan + } + CTERelationDef(Filter(newCombinedPred, newChild), id, Some((originalPlan, newPreds))) + } else if (needsPruning(cteDef.child, newAttrSet)) { + CTERelationDef(Project(newAttrSet.toSeq, cteDef.child), id, Some((originalPlan, preds))) + } else { + cteDef + } + + case cteRef @ CTERelationRef(cteId, _, output, _) => + val (cteDef, _, _, newAttrSet) = cteMap(cteId) + if (newAttrSet.size < output.size) { + val indices = newAttrSet.toSeq.map(cteDef.output.indexOf) + val newOutput = indices.map(output) + cteRef.copy(output = newOutput) + } else { + // Do not change the order of output columns if no column is pruned, in which case there + // might be no Project and the order is important. + cteRef + } + } + + private def isTruePredicate(predicates: Seq[Expression]): Boolean = { + predicates.length == 1 && predicates.head == Literal.TrueLiteral + } + + private def needsPruning(sourcePlan: LogicalPlan, attributeSet: AttributeSet): Boolean = { + attributeSet.size < sourcePlan.outputSet.size && attributeSet.subsetOf(sourcePlan.outputSet) + } +} + +/** + * Clean up temporary info from [[CTERelationDef]] nodes. This rule should be called after all + * iterations of [[PushdownPredicatesAndPruneColumnsForCTEDef]] are done. + */ +object CleanUpTempCTEInfo extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = + plan.transformWithPruning(_.containsPattern(CTE)) { + case cteDef @ CTERelationDef(_, _, Some(_)) => + cteDef.copy(originalPlanWithPredicates = None) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceCTERefWithRepartition.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceCTERefWithRepartition.scala new file mode 100644 index 0000000000000..e0d0417ce5161 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceCTERefWithRepartition.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.analysis.DeduplicateRelations +import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression} +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION} + +/** + * Replaces CTE references that have not been previously inlined with [[Repartition]] operations + * which will then be planned as shuffles and reused across different reference points. + * + * Note that this rule should be called at the very end of the optimization phase to best guarantee + * that CTE repartition shuffles are reused. + */ +object ReplaceCTERefWithRepartition extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + case _: Subquery => plan + case _ => + replaceWithRepartition(plan, mutable.HashMap.empty[Long, LogicalPlan]) + } + + private def replaceWithRepartition( + plan: LogicalPlan, + cteMap: mutable.HashMap[Long, LogicalPlan]): LogicalPlan = plan match { + case WithCTE(child, cteDefs) => + cteDefs.foreach { cteDef => + val inlined = replaceWithRepartition(cteDef.child, cteMap) + val withRepartition = if (inlined.isInstanceOf[RepartitionOperation]) { + // If the CTE definition plan itself is a repartition operation, we do not need to add an + // extra repartition shuffle. 
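          // For illustration (assumed, not from the original change): a definition such as
          // `SELECT /*+ REPARTITION(8) */ ...` or a `Dataset.repartition(n)` child already
          // ends in a RepartitionOperation, so it is reused as is; any other definition is
          // wrapped in a round-robin Repartition below so that the resulting shuffle can be
          // reused across all references of this CTE.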
+ inlined + } else { + Repartition(conf.numShufflePartitions, shuffle = true, inlined) + } + cteMap.put(cteDef.id, withRepartition) + } + replaceWithRepartition(child, cteMap) + + case ref: CTERelationRef => + val cteDefPlan = cteMap(ref.cteId) + if (ref.outputSet == cteDefPlan.outputSet) { + cteDefPlan + } else { + val ctePlan = DeduplicateRelations( + Join(cteDefPlan, cteDefPlan, Inner, None, JoinHint(None, None))).children(1) + val projectList = ref.output.zip(ctePlan.output).map { case (tgtAttr, srcAttr) => + Alias(srcAttr, tgtAttr.name)(exprId = tgtAttr.exprId) + } + Project(projectList, ctePlan) + } + + case _ if plan.containsPattern(CTE) => + plan + .withNewChildren(plan.children.map(c => replaceWithRepartition(c, cteMap))) + .transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) { + case e: SubqueryExpression => + e.withNewPlan(replaceWithRepartition(e.plan, cteMap)) + } + + case _ => plan + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 5d749b8fc4b53..0f8df5df3764a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -448,6 +448,14 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] subqueries ++ subqueries.flatMap(_.subqueriesAll) } + /** + * This method is similar to the transform method, but also applies the given partial function + * also to all the plans in the subqueries of a node. This method is useful when we want + * to rewrite the whole plan, include its subqueries, in one go. + */ + def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = + transformDownWithSubqueries(f) + /** * Returns a copy of this node where the given partial function has been recursively applied * first to the subqueries in this node's children, then this node's children, and finally @@ -465,6 +473,29 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] } } + /** + * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. + * Returns a copy of this node where the given partial function has been recursively applied + * first to this node, then this node's subqueries and finally this node's children. + * When the partial function does not apply to a given node, it is left unchanged. + */ + def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { + val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { + override def isDefinedAt(x: PlanType): Boolean = true + + override def apply(plan: PlanType): PlanType = { + val transformed = f.applyOrElse[PlanType, PlanType](plan, identity) + transformed transformExpressionsDown { + case planExpression: PlanExpression[PlanType] => + val newPlan = planExpression.plan.transformDownWithSubqueries(f) + planExpression.withNewPlan(newPlan) + } + } + } + + transformDown(g) + } + /** * A variant of `collect`. 
This method not only apply the given function to all elements in this * plan, also considering all the plans in its (nested) subqueries diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 895eeb772075d..e5eab691d14fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -659,8 +659,15 @@ case class UnresolvedWith( * A wrapper for CTE definition plan with a unique ID. * @param child The CTE definition query plan. * @param id The unique ID for this CTE definition. + * @param originalPlanWithPredicates The original query plan before predicate pushdown and the + * predicates that have been pushed down into `child`. This is + * a temporary field used by optimization rules for CTE predicate + * pushdown to help ensure rule idempotency. */ -case class CTERelationDef(child: LogicalPlan, id: Long = CTERelationDef.newId) extends UnaryNode { +case class CTERelationDef( + child: LogicalPlan, + id: Long = CTERelationDef.newId, + originalPlanWithPredicates: Option[(LogicalPlan, Seq[Expression])] = None) extends UnaryNode { final override val nodePatterns: Seq[TreePattern] = Seq(CTE) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 804f1edbe06fd..7dde85014e7c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -108,7 +108,8 @@ trait AnalysisTest extends PlanTest { case v: View if v.isTempViewStoringAnalyzedPlan => v.child } val actualPlan = if (inlineCTE) { - InlineCTE(transformed) + val inlineCTE = InlineCTE() + inlineCTE(transformed) } else { transformed } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 4b74a96702c8b..9ea769b4cf153 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -21,8 +21,6 @@ import java.io.{BufferedWriter, OutputStreamWriter} import java.util.UUID import java.util.concurrent.atomic.AtomicLong -import scala.collection.mutable - import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging @@ -32,7 +30,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CreateTableAsSelect, CTERelationDef, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import 
org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString @@ -64,17 +62,6 @@ class QueryExecution( // TODO: Move the planner an optimizer into here from SessionState. protected def planner = sparkSession.sessionState.planner - // The CTE map for the planner shared by the main query and all subqueries. - private val cteMap = mutable.HashMap.empty[Long, CTERelationDef] - - def withCteMap[T](f: => T): T = { - val old = QueryExecution.currentCteMap.get() - QueryExecution.currentCteMap.set(cteMap) - try f finally { - QueryExecution.currentCteMap.set(old) - } - } - def assertAnalyzed(): Unit = analyzed def assertSupported(): Unit = { @@ -147,7 +134,7 @@ class QueryExecution( private def assertOptimized(): Unit = optimizedPlan - lazy val sparkPlan: SparkPlan = withCteMap { + lazy val sparkPlan: SparkPlan = { // We need to materialize the optimizedPlan here because sparkPlan is also tracked under // the planning phase assertOptimized() @@ -160,7 +147,7 @@ class QueryExecution( // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = withCteMap { + lazy val executedPlan: SparkPlan = { // We need to materialize the optimizedPlan here, before tracking the planning phase, to ensure // that the optimization time is not counted as part of the planning phase. assertOptimized() @@ -499,8 +486,4 @@ object QueryExecution { val preparationRules = preparations(session, Option(InsertAdaptiveSparkPlan(context)), true) prepareForExecution(preparationRules, sparkPlan.clone()) } - - private val currentCteMap = new ThreadLocal[mutable.HashMap[Long, CTERelationDef]]() - - def cteMap: mutable.HashMap[Long, CTERelationDef] = currentCteMap.get() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 8c134363af112..d9457a20d91c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -76,7 +76,8 @@ class SparkOptimizer( ColumnPruning, PushPredicateThroughNonJoin, RemoveNoopOperators) :+ - Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) :+ + Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition) override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+ ExtractPythonUDFFromJoinCondition.ruleName :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 32ac58f8353ab..6994aaf47dfba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -44,7 +44,6 @@ class SparkPlanner(val session: SparkSession, val experimentalMethods: Experimen JoinSelection :: InMemoryScans :: SparkScripts :: - WithCTEStrategy :: BasicOperators :: Nil) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8022d476ce015..1b8d347ed8a14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.physical.RoundRobinPartitioning import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.aggregate.AggUtils @@ -678,36 +677,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - /** - * Strategy to plan CTE relations left not inlined. - */ - object WithCTEStrategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case WithCTE(plan, cteDefs) => - val cteMap = QueryExecution.cteMap - cteDefs.foreach { cteDef => - cteMap.put(cteDef.id, cteDef) - } - planLater(plan) :: Nil - - case r: CTERelationRef => - val ctePlan = QueryExecution.cteMap(r.cteId).child - val projectList = r.output.zip(ctePlan.output).map { case (tgtAttr, srcAttr) => - Alias(srcAttr, tgtAttr.name)(exprId = tgtAttr.exprId) - } - val newPlan = Project(projectList, ctePlan) - // Plan CTE ref as a repartition shuffle so that all refs of the same CTE def will share - // an Exchange reuse at runtime. - // TODO create a new identity partitioning instead of using RoundRobinPartitioning. - exchange.ShuffleExchangeExec( - RoundRobinPartitioning(conf.numShufflePartitions), - planLater(newPlan), - REPARTITION_BY_COL) :: Nil - - case _ => Nil - } - } - object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 808959363ac63..4a2740656688f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -150,9 +150,7 @@ case class AdaptiveSparkPlanExec( collapseCodegenStagesRule ) - private def optimizeQueryStage( - plan: SparkPlan, - isFinalStage: Boolean): SparkPlan = context.qe.withCteMap { + private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = { val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) => val applied = rule.apply(latestPlan) val result = rule match { @@ -643,8 +641,7 @@ case class AdaptiveSparkPlanExec( /** * Re-optimize and run physical planning on the current logical plan based on the latest stats. 
*/ - private def reOptimize( - logicalPlan: LogicalPlan): (SparkPlan, LogicalPlan) = context.qe.withCteMap { + private def reOptimize(logicalPlan: LogicalPlan): (SparkPlan, LogicalPlan) = { logicalPlan.invalidateStatsCache() val optimized = optimizer.execute(logicalPlan) val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next() diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql index a76a010722090..4c80b268c20c3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql @@ -145,3 +145,45 @@ SELECT t1c, (SELECT t1c WHERE t1c = 8) FROM t1; SELECT t1c, t1d, (SELECT c + d FROM (SELECT t1c AS c, t1d AS d)) FROM t1; SELECT t1c, (SELECT SUM(c) FROM (SELECT t1c AS c)) FROM t1; SELECT t1a, (SELECT SUM(t2b) FROM t2 JOIN (SELECT t1a AS a) ON t2a = a) FROM t1; + +-- CTE in correlated scalar subqueries +CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t1(c1, c2); +CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (0, 2), (0, 3) t2(c1, c2); + +-- Single row subquery +SELECT c1, (WITH t AS (SELECT 1 AS a) SELECT a + c1 FROM t) FROM t1; +-- Correlation in CTE. +SELECT c1, (WITH t AS (SELECT * FROM t2 WHERE c1 = t1.c1) SELECT SUM(c2) FROM t) FROM t1; +-- Multiple CTE definitions. +SELECT c1, ( + WITH t3 AS (SELECT c1 + 1 AS c1, c2 + 1 AS c2 FROM t2), + t4 AS (SELECT * FROM t3 WHERE t1.c1 = c1) + SELECT SUM(c2) FROM t4 +) FROM t1; +-- Multiple CTE references. +SELECT c1, ( + WITH t AS (SELECT * FROM t2) + SELECT SUM(c2) FROM (SELECT c1, c2 FROM t UNION SELECT c2, c1 FROM t) r(c1, c2) + WHERE c1 = t1.c1 +) FROM t1; +-- Reference CTE in both the main query and the subquery. +WITH v AS (SELECT * FROM t2) +SELECT * FROM t1 WHERE c1 > ( + WITH t AS (SELECT * FROM t2) + SELECT COUNT(*) FROM v WHERE c1 = t1.c1 AND c1 > (SELECT SUM(c2) FROM t WHERE c1 = v.c1) +); +-- Single row subquery that references CTE in the main query. +WITH t AS (SELECT 1 AS a) +SELECT c1, (SELECT a FROM t WHERE a = c1) FROM t1; +-- Multiple CTE references with non-deterministic CTEs. 
+WITH +v1 AS (SELECT c1, c2, rand(0) c3 FROM t1), +v2 AS (SELECT c1, c2, rand(0) c4 FROM v1 WHERE c3 IN (SELECT c3 FROM v1)) +SELECT c1, ( + WITH v3 AS (SELECT c1, c2, rand(0) c5 FROM t2) + SELECT COUNT(*) FROM ( + SELECT * FROM v2 WHERE c1 > 0 + UNION SELECT * FROM v2 WHERE c2 > 0 + UNION SELECT * FROM v3 WHERE c2 > 0 + ) WHERE c1 = v1.c1 +) FROM v1; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out index 8fac940f8efd0..3eb1c6ffba187 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 26 -- !query @@ -317,3 +317,104 @@ val1d NULL val1e 8 val1e 8 val1e 8 + + +-- !query +CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t1(c1, c2) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (0, 2), (0, 3) t2(c1, c2) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT c1, (WITH t AS (SELECT 1 AS a) SELECT a + c1 FROM t) FROM t1 +-- !query schema +struct +-- !query output +0 1 +1 2 + + +-- !query +SELECT c1, (WITH t AS (SELECT * FROM t2 WHERE c1 = t1.c1) SELECT SUM(c2) FROM t) FROM t1 +-- !query schema +struct +-- !query output +0 5 +1 NULL + + +-- !query +SELECT c1, ( + WITH t3 AS (SELECT c1 + 1 AS c1, c2 + 1 AS c2 FROM t2), + t4 AS (SELECT * FROM t3 WHERE t1.c1 = c1) + SELECT SUM(c2) FROM t4 +) FROM t1 +-- !query schema +struct +-- !query output +0 NULL +1 7 + + +-- !query +SELECT c1, ( + WITH t AS (SELECT * FROM t2) + SELECT SUM(c2) FROM (SELECT c1, c2 FROM t UNION SELECT c2, c1 FROM t) r(c1, c2) + WHERE c1 = t1.c1 +) FROM t1 +-- !query schema +struct +-- !query output +0 5 +1 NULL + + +-- !query +WITH v AS (SELECT * FROM t2) +SELECT * FROM t1 WHERE c1 > ( + WITH t AS (SELECT * FROM t2) + SELECT COUNT(*) FROM v WHERE c1 = t1.c1 AND c1 > (SELECT SUM(c2) FROM t WHERE c1 = v.c1) +) +-- !query schema +struct +-- !query output +1 2 + + +-- !query +WITH t AS (SELECT 1 AS a) +SELECT c1, (SELECT a FROM t WHERE a = c1) FROM t1 +-- !query schema +struct +-- !query output +0 NULL +1 1 + + +-- !query +WITH +v1 AS (SELECT c1, c2, rand(0) c3 FROM t1), +v2 AS (SELECT c1, c2, rand(0) c4 FROM v1 WHERE c3 IN (SELECT c3 FROM v1)) +SELECT c1, ( + WITH v3 AS (SELECT c1, c2, rand(0) c5 FROM t2) + SELECT COUNT(*) FROM ( + SELECT * FROM v2 WHERE c1 > 0 + UNION SELECT * FROM v2 WHERE c2 > 0 + UNION SELECT * FROM v3 WHERE c2 > 0 + ) WHERE c1 = v1.c1 +) FROM v1 +-- !query schema +struct +-- !query output +0 3 +1 1 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt index 5bf5193487b07..7f419ce3eaf6d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt @@ -360,19 +360,19 @@ Right keys [1]: [i_item_sk#14] Join condition: None (61) Project [codegen id : 25] -Output [3]: [d_date#12, i_item_sk#14, substr(i_item_desc#15, 1, 30) AS _groupingexpression#47] +Output [3]: [d_date#12, i_item_sk#14, 
substr(i_item_desc#15, 1, 30) AS _groupingexpression#17] Input [4]: [ss_item_sk#8, d_date#12, i_item_sk#14, i_item_desc#15] (62) HashAggregate [codegen id : 25] -Input [3]: [d_date#12, i_item_sk#14, _groupingexpression#47] -Keys [3]: [_groupingexpression#47, i_item_sk#14, d_date#12] +Input [3]: [d_date#12, i_item_sk#14, _groupingexpression#17] +Keys [3]: [_groupingexpression#17, i_item_sk#14, d_date#12] Functions [1]: [partial_count(1)] Aggregate Attributes [1]: [count#18] -Results [4]: [_groupingexpression#47, i_item_sk#14, d_date#12, count#19] +Results [4]: [_groupingexpression#17, i_item_sk#14, d_date#12, count#19] (63) HashAggregate [codegen id : 25] -Input [4]: [_groupingexpression#47, i_item_sk#14, d_date#12, count#19] -Keys [3]: [_groupingexpression#47, i_item_sk#14, d_date#12] +Input [4]: [_groupingexpression#17, i_item_sk#14, d_date#12, count#19] +Keys [3]: [_groupingexpression#17, i_item_sk#14, d_date#12] Functions [1]: [count(1)] Aggregate Attributes [1]: [count(1)#20] Results [2]: [i_item_sk#14 AS item_sk#21, count(1)#20 AS cnt#22] @@ -400,7 +400,7 @@ Input [5]: [ws_item_sk#41, ws_bill_customer_sk#42, ws_quantity#43, ws_list_price (69) Exchange Input [4]: [ws_bill_customer_sk#42, ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45] -Arguments: hashpartitioning(ws_bill_customer_sk#42, 5), ENSURE_REQUIREMENTS, [id=#48] +Arguments: hashpartitioning(ws_bill_customer_sk#42, 5), ENSURE_REQUIREMENTS, [id=#47] (70) Sort [codegen id : 27] Input [4]: [ws_bill_customer_sk#42, ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45] @@ -433,11 +433,11 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26, c_customer_sk# Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#29] Keys [1]: [c_customer_sk#29] Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] -Aggregate Attributes [2]: [sum#49, isEmpty#50] -Results [3]: [c_customer_sk#29, sum#51, isEmpty#52] +Aggregate Attributes [2]: [sum#48, isEmpty#49] +Results [3]: [c_customer_sk#29, sum#50, isEmpty#51] (78) HashAggregate [codegen id : 32] -Input [3]: [c_customer_sk#29, sum#51, isEmpty#52] +Input [3]: [c_customer_sk#29, sum#50, isEmpty#51] Keys [1]: [c_customer_sk#29] Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] @@ -465,16 +465,16 @@ Output [3]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45] Input [4]: [ws_bill_customer_sk#42, ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45] (84) ReusedExchange [Reuses operator id: 95] -Output [1]: [d_date_sk#53] +Output [1]: [d_date_sk#52] (85) BroadcastHashJoin [codegen id : 34] Left keys [1]: [ws_sold_date_sk#45] -Right keys [1]: [d_date_sk#53] +Right keys [1]: [d_date_sk#52] Join condition: None (86) Project [codegen id : 34] -Output [1]: [CheckOverflow((promote_precision(cast(ws_quantity#43 as decimal(12,2))) * promote_precision(cast(ws_list_price#44 as decimal(12,2)))), DecimalType(18,2)) AS sales#54] -Input [4]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45, d_date_sk#53] +Output [1]: [CheckOverflow((promote_precision(cast(ws_quantity#43 as decimal(12,2))) * promote_precision(cast(ws_list_price#44 as 
decimal(12,2)))), DecimalType(18,2)) AS sales#53] +Input [4]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45, d_date_sk#52] (87) Union @@ -482,19 +482,19 @@ Input [4]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45, d_date_sk#53] Input [1]: [sales#40] Keys: [] Functions [1]: [partial_sum(sales#40)] -Aggregate Attributes [2]: [sum#55, isEmpty#56] -Results [2]: [sum#57, isEmpty#58] +Aggregate Attributes [2]: [sum#54, isEmpty#55] +Results [2]: [sum#56, isEmpty#57] (89) Exchange -Input [2]: [sum#57, isEmpty#58] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#59] +Input [2]: [sum#56, isEmpty#57] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#58] (90) HashAggregate [codegen id : 36] -Input [2]: [sum#57, isEmpty#58] +Input [2]: [sum#56, isEmpty#57] Keys: [] Functions [1]: [sum(sales#40)] -Aggregate Attributes [1]: [sum(sales#40)#60] -Results [1]: [sum(sales#40)#60 AS sum(sales)#61] +Aggregate Attributes [1]: [sum(sales#40)#59] +Results [1]: [sum(sales#40)#59 AS sum(sales)#60] ===== Subqueries ===== @@ -507,26 +507,26 @@ BroadcastExchange (95) (91) Scan parquet default.date_dim -Output [3]: [d_date_sk#39, d_year#62, d_moy#63] +Output [3]: [d_date_sk#39, d_year#61, d_moy#62] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), EqualTo(d_year,2000), EqualTo(d_moy,2), IsNotNull(d_date_sk)] ReadSchema: struct (92) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#39, d_year#62, d_moy#63] +Input [3]: [d_date_sk#39, d_year#61, d_moy#62] (93) Filter [codegen id : 1] -Input [3]: [d_date_sk#39, d_year#62, d_moy#63] -Condition : ((((isnotnull(d_year#62) AND isnotnull(d_moy#63)) AND (d_year#62 = 2000)) AND (d_moy#63 = 2)) AND isnotnull(d_date_sk#39)) +Input [3]: [d_date_sk#39, d_year#61, d_moy#62] +Condition : ((((isnotnull(d_year#61) AND isnotnull(d_moy#62)) AND (d_year#61 = 2000)) AND (d_moy#62 = 2)) AND isnotnull(d_date_sk#39)) (94) Project [codegen id : 1] Output [1]: [d_date_sk#39] -Input [3]: [d_date_sk#39, d_year#62, d_moy#63] +Input [3]: [d_date_sk#39, d_year#61, d_moy#62] (95) BroadcastExchange Input [1]: [d_date_sk#39] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#64] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#63] Subquery:2 Hosting operator id = 5 Hosting Expression = ss_sold_date_sk#9 IN dynamicpruning#10 BroadcastExchange (100) @@ -537,26 +537,26 @@ BroadcastExchange (100) (96) Scan parquet default.date_dim -Output [3]: [d_date_sk#11, d_date#12, d_year#65] +Output [3]: [d_date_sk#11, d_date#12, d_year#64] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)] ReadSchema: struct (97) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#11, d_date#12, d_year#65] +Input [3]: [d_date_sk#11, d_date#12, d_year#64] (98) Filter [codegen id : 1] -Input [3]: [d_date_sk#11, d_date#12, d_year#65] -Condition : (d_year#65 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#11)) +Input [3]: [d_date_sk#11, d_date#12, d_year#64] +Condition : (d_year#64 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#11)) (99) Project [codegen id : 1] Output [2]: [d_date_sk#11, d_date#12] -Input [3]: [d_date_sk#11, d_date#12, d_year#65] +Input [3]: [d_date_sk#11, d_date#12, d_year#64] (100) BroadcastExchange Input [2]: [d_date_sk#11, d_date#12] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#66] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#65] Subquery:3 Hosting operator id = 44 Hosting Expression = Subquery scalar-subquery#37, [id=#38] * HashAggregate (117) @@ -579,89 +579,89 @@ Subquery:3 Hosting operator id = 44 Hosting Expression = Subquery scalar-subquer (101) Scan parquet default.store_sales -Output [4]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, ss_sold_date_sk#70] +Output [4]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68, ss_sold_date_sk#69] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#70), dynamicpruningexpression(ss_sold_date_sk#70 IN dynamicpruning#71)] +PartitionFilters: [isnotnull(ss_sold_date_sk#69), dynamicpruningexpression(ss_sold_date_sk#69 IN dynamicpruning#70)] PushedFilters: [IsNotNull(ss_customer_sk)] ReadSchema: struct (102) ColumnarToRow [codegen id : 2] -Input [4]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, ss_sold_date_sk#70] +Input [4]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68, ss_sold_date_sk#69] (103) Filter [codegen id : 2] -Input [4]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, ss_sold_date_sk#70] -Condition : isnotnull(ss_customer_sk#67) +Input [4]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68, ss_sold_date_sk#69] +Condition : isnotnull(ss_customer_sk#66) (104) ReusedExchange [Reuses operator id: 122] -Output [1]: [d_date_sk#72] +Output [1]: [d_date_sk#71] (105) BroadcastHashJoin [codegen id : 2] -Left keys [1]: [ss_sold_date_sk#70] -Right keys [1]: [d_date_sk#72] +Left keys [1]: [ss_sold_date_sk#69] +Right keys [1]: [d_date_sk#71] Join condition: None (106) Project [codegen id : 2] -Output [3]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69] -Input [5]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, ss_sold_date_sk#70, d_date_sk#72] +Output [3]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68] +Input [5]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68, ss_sold_date_sk#69, d_date_sk#71] (107) Exchange -Input [3]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69] -Arguments: hashpartitioning(ss_customer_sk#67, 5), ENSURE_REQUIREMENTS, [id=#73] +Input [3]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68] +Arguments: hashpartitioning(ss_customer_sk#66, 5), ENSURE_REQUIREMENTS, [id=#72] (108) Sort [codegen id : 3] -Input [3]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69] -Arguments: [ss_customer_sk#67 ASC NULLS FIRST], false, 0 +Input [3]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68] +Arguments: [ss_customer_sk#66 ASC NULLS FIRST], false, 0 (109) ReusedExchange [Reuses operator id: 38] -Output [1]: [c_customer_sk#74] +Output [1]: [c_customer_sk#73] (110) Sort [codegen id : 5] -Input [1]: [c_customer_sk#74] -Arguments: [c_customer_sk#74 ASC NULLS FIRST], false, 0 +Input [1]: [c_customer_sk#73] +Arguments: [c_customer_sk#73 ASC NULLS FIRST], false, 0 (111) SortMergeJoin [codegen id : 6] -Left keys [1]: [ss_customer_sk#67] -Right keys [1]: [c_customer_sk#74] +Left keys [1]: [ss_customer_sk#66] +Right keys [1]: [c_customer_sk#73] Join condition: None (112) Project [codegen id : 6] -Output [3]: [ss_quantity#68, ss_sales_price#69, c_customer_sk#74] -Input [4]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, c_customer_sk#74] +Output [3]: [ss_quantity#67, ss_sales_price#68, c_customer_sk#73] +Input [4]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68, c_customer_sk#73] (113) HashAggregate [codegen id : 6] 
-Input [3]: [ss_quantity#68, ss_sales_price#69, c_customer_sk#74] -Keys [1]: [c_customer_sk#74] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_sales_price#69 as decimal(12,2)))), DecimalType(18,2)))] -Aggregate Attributes [2]: [sum#75, isEmpty#76] -Results [3]: [c_customer_sk#74, sum#77, isEmpty#78] +Input [3]: [ss_quantity#67, ss_sales_price#68, c_customer_sk#73] +Keys [1]: [c_customer_sk#73] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#67 as decimal(12,2))) * promote_precision(cast(ss_sales_price#68 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#74, isEmpty#75] +Results [3]: [c_customer_sk#73, sum#76, isEmpty#77] (114) HashAggregate [codegen id : 6] -Input [3]: [c_customer_sk#74, sum#77, isEmpty#78] -Keys [1]: [c_customer_sk#74] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_sales_price#69 as decimal(12,2)))), DecimalType(18,2)))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_sales_price#69 as decimal(12,2)))), DecimalType(18,2)))#79] -Results [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * promote_precision(cast(ss_sales_price#69 as decimal(12,2)))), DecimalType(18,2)))#79 AS csales#80] +Input [3]: [c_customer_sk#73, sum#76, isEmpty#77] +Keys [1]: [c_customer_sk#73] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#67 as decimal(12,2))) * promote_precision(cast(ss_sales_price#68 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#67 as decimal(12,2))) * promote_precision(cast(ss_sales_price#68 as decimal(12,2)))), DecimalType(18,2)))#78] +Results [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#67 as decimal(12,2))) * promote_precision(cast(ss_sales_price#68 as decimal(12,2)))), DecimalType(18,2)))#78 AS csales#79] (115) HashAggregate [codegen id : 6] -Input [1]: [csales#80] +Input [1]: [csales#79] Keys: [] -Functions [1]: [partial_max(csales#80)] -Aggregate Attributes [1]: [max#81] -Results [1]: [max#82] +Functions [1]: [partial_max(csales#79)] +Aggregate Attributes [1]: [max#80] +Results [1]: [max#81] (116) Exchange -Input [1]: [max#82] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#83] +Input [1]: [max#81] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#82] (117) HashAggregate [codegen id : 7] -Input [1]: [max#82] +Input [1]: [max#81] Keys: [] -Functions [1]: [max(csales#80)] -Aggregate Attributes [1]: [max(csales#80)#84] -Results [1]: [max(csales#80)#84 AS tpcds_cmax#85] +Functions [1]: [max(csales#79)] +Aggregate Attributes [1]: [max(csales#79)#83] +Results [1]: [max(csales#79)#83 AS tpcds_cmax#84] -Subquery:4 Hosting operator id = 101 Hosting Expression = ss_sold_date_sk#70 IN dynamicpruning#71 +Subquery:4 Hosting operator id = 101 Hosting Expression = ss_sold_date_sk#69 IN dynamicpruning#70 BroadcastExchange (122) +- * Project (121) +- * Filter (120) @@ -670,26 +670,26 @@ BroadcastExchange (122) (118) Scan parquet default.date_dim -Output [2]: [d_date_sk#72, d_year#86] +Output [2]: [d_date_sk#71, d_year#85] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)] ReadSchema: struct (119) ColumnarToRow [codegen id : 1] -Input [2]: 
[d_date_sk#72, d_year#86] +Input [2]: [d_date_sk#71, d_year#85] (120) Filter [codegen id : 1] -Input [2]: [d_date_sk#72, d_year#86] -Condition : (d_year#86 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#72)) +Input [2]: [d_date_sk#71, d_year#85] +Condition : (d_year#85 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#71)) (121) Project [codegen id : 1] -Output [1]: [d_date_sk#72] -Input [2]: [d_date_sk#72, d_year#86] +Output [1]: [d_date_sk#71] +Input [2]: [d_date_sk#71, d_year#85] (122) BroadcastExchange -Input [1]: [d_date_sk#72] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#87] +Input [1]: [d_date_sk#71] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#86] Subquery:5 Hosting operator id = 52 Hosting Expression = ws_sold_date_sk#45 IN dynamicpruning#6 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/explain.txt index 3de1f24613451..4d1109078e346 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/explain.txt @@ -508,19 +508,19 @@ Right keys [1]: [i_item_sk#14] Join condition: None (84) Project [codegen id : 35] -Output [3]: [d_date#12, i_item_sk#14, substr(i_item_desc#15, 1, 30) AS _groupingexpression#57] +Output [3]: [d_date#12, i_item_sk#14, substr(i_item_desc#15, 1, 30) AS _groupingexpression#17] Input [4]: [ss_item_sk#8, d_date#12, i_item_sk#14, i_item_desc#15] (85) HashAggregate [codegen id : 35] -Input [3]: [d_date#12, i_item_sk#14, _groupingexpression#57] -Keys [3]: [_groupingexpression#57, i_item_sk#14, d_date#12] +Input [3]: [d_date#12, i_item_sk#14, _groupingexpression#17] +Keys [3]: [_groupingexpression#17, i_item_sk#14, d_date#12] Functions [1]: [partial_count(1)] Aggregate Attributes [1]: [count#18] -Results [4]: [_groupingexpression#57, i_item_sk#14, d_date#12, count#19] +Results [4]: [_groupingexpression#17, i_item_sk#14, d_date#12, count#19] (86) HashAggregate [codegen id : 35] -Input [4]: [_groupingexpression#57, i_item_sk#14, d_date#12, count#19] -Keys [3]: [_groupingexpression#57, i_item_sk#14, d_date#12] +Input [4]: [_groupingexpression#17, i_item_sk#14, d_date#12, count#19] +Keys [3]: [_groupingexpression#17, i_item_sk#14, d_date#12] Functions [1]: [count(1)] Aggregate Attributes [1]: [count(1)#20] Results [2]: [i_item_sk#14 AS item_sk#21, count(1)#20 AS cnt#22] @@ -548,7 +548,7 @@ Input [5]: [ws_item_sk#51, ws_bill_customer_sk#52, ws_quantity#53, ws_list_price (92) Exchange Input [4]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, ws_sold_date_sk#55] -Arguments: hashpartitioning(ws_bill_customer_sk#52, 5), ENSURE_REQUIREMENTS, [id=#58] +Arguments: hashpartitioning(ws_bill_customer_sk#52, 5), ENSURE_REQUIREMENTS, [id=#57] (93) Sort [codegen id : 37] Input [4]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, ws_sold_date_sk#55] @@ -581,11 +581,11 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26, c_customer_sk# Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#29] Keys [1]: [c_customer_sk#29] Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] -Aggregate Attributes [2]: [sum#59, isEmpty#60] -Results [3]: 
[c_customer_sk#29, sum#61, isEmpty#62] +Aggregate Attributes [2]: [sum#58, isEmpty#59] +Results [3]: [c_customer_sk#29, sum#60, isEmpty#61] (101) HashAggregate [codegen id : 42] -Input [3]: [c_customer_sk#29, sum#61, isEmpty#62] +Input [3]: [c_customer_sk#29, sum#60, isEmpty#61] Keys [1]: [c_customer_sk#29] Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] @@ -609,23 +609,23 @@ Right keys [1]: [c_customer_sk#29] Join condition: None (106) ReusedExchange [Reuses operator id: 134] -Output [1]: [d_date_sk#63] +Output [1]: [d_date_sk#62] (107) BroadcastHashJoin [codegen id : 44] Left keys [1]: [ws_sold_date_sk#55] -Right keys [1]: [d_date_sk#63] +Right keys [1]: [d_date_sk#62] Join condition: None (108) Project [codegen id : 44] Output [3]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54] -Input [5]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, ws_sold_date_sk#55, d_date_sk#63] +Input [5]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, ws_sold_date_sk#55, d_date_sk#62] (109) ReusedExchange [Reuses operator id: 55] -Output [3]: [c_customer_sk#64, c_first_name#65, c_last_name#66] +Output [3]: [c_customer_sk#63, c_first_name#64, c_last_name#65] (110) Sort [codegen id : 46] -Input [3]: [c_customer_sk#64, c_first_name#65, c_last_name#66] -Arguments: [c_customer_sk#64 ASC NULLS FIRST], false, 0 +Input [3]: [c_customer_sk#63, c_first_name#64, c_last_name#65] +Arguments: [c_customer_sk#63 ASC NULLS FIRST], false, 0 (111) ReusedExchange [Reuses operator id: 34] Output [3]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26] @@ -654,11 +654,11 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26, c_customer_sk# Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#29] Keys [1]: [c_customer_sk#29] Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] -Aggregate Attributes [2]: [sum#59, isEmpty#60] -Results [3]: [c_customer_sk#29, sum#61, isEmpty#62] +Aggregate Attributes [2]: [sum#58, isEmpty#59] +Results [3]: [c_customer_sk#29, sum#60, isEmpty#61] (118) HashAggregate [codegen id : 51] -Input [3]: [c_customer_sk#29, sum#61, isEmpty#62] +Input [3]: [c_customer_sk#29, sum#60, isEmpty#61] Keys [1]: [c_customer_sk#29] Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))] Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), DecimalType(18,2)))#35] @@ -677,36 +677,36 @@ Input [1]: [c_customer_sk#29] Arguments: [c_customer_sk#29 ASC NULLS FIRST], false, 0 (122) SortMergeJoin [codegen id : 52] -Left keys [1]: [c_customer_sk#64] +Left keys [1]: [c_customer_sk#63] Right keys [1]: [c_customer_sk#29] Join condition: None (123) SortMergeJoin [codegen id : 53] Left keys [1]: [ws_bill_customer_sk#52] -Right keys [1]: [c_customer_sk#64] +Right keys [1]: [c_customer_sk#63] Join condition: None (124) Project [codegen id : 53] -Output [4]: [ws_quantity#53, ws_list_price#54, c_first_name#65, 
c_last_name#66] -Input [6]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, c_customer_sk#64, c_first_name#65, c_last_name#66] +Output [4]: [ws_quantity#53, ws_list_price#54, c_first_name#64, c_last_name#65] +Input [6]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, c_customer_sk#63, c_first_name#64, c_last_name#65] (125) HashAggregate [codegen id : 53] -Input [4]: [ws_quantity#53, ws_list_price#54, c_first_name#65, c_last_name#66] -Keys [2]: [c_last_name#66, c_first_name#65] +Input [4]: [ws_quantity#53, ws_list_price#54, c_first_name#64, c_last_name#65] +Keys [2]: [c_last_name#65, c_first_name#64] Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2)))] -Aggregate Attributes [2]: [sum#67, isEmpty#68] -Results [4]: [c_last_name#66, c_first_name#65, sum#69, isEmpty#70] +Aggregate Attributes [2]: [sum#66, isEmpty#67] +Results [4]: [c_last_name#65, c_first_name#64, sum#68, isEmpty#69] (126) Exchange -Input [4]: [c_last_name#66, c_first_name#65, sum#69, isEmpty#70] -Arguments: hashpartitioning(c_last_name#66, c_first_name#65, 5), ENSURE_REQUIREMENTS, [id=#71] +Input [4]: [c_last_name#65, c_first_name#64, sum#68, isEmpty#69] +Arguments: hashpartitioning(c_last_name#65, c_first_name#64, 5), ENSURE_REQUIREMENTS, [id=#70] (127) HashAggregate [codegen id : 54] -Input [4]: [c_last_name#66, c_first_name#65, sum#69, isEmpty#70] -Keys [2]: [c_last_name#66, c_first_name#65] +Input [4]: [c_last_name#65, c_first_name#64, sum#68, isEmpty#69] +Keys [2]: [c_last_name#65, c_first_name#64] Functions [1]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2)))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2)))#72] -Results [3]: [c_last_name#66, c_first_name#65, sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2)))#72 AS sales#73] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2)))#71] +Results [3]: [c_last_name#65, c_first_name#64, sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), DecimalType(18,2)))#71 AS sales#72] (128) Union @@ -725,26 +725,26 @@ BroadcastExchange (134) (130) Scan parquet default.date_dim -Output [3]: [d_date_sk#39, d_year#74, d_moy#75] +Output [3]: [d_date_sk#39, d_year#73, d_moy#74] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), EqualTo(d_year,2000), EqualTo(d_moy,2), IsNotNull(d_date_sk)] ReadSchema: struct (131) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#39, d_year#74, d_moy#75] +Input [3]: [d_date_sk#39, d_year#73, d_moy#74] (132) Filter [codegen id : 1] -Input [3]: [d_date_sk#39, d_year#74, d_moy#75] -Condition : ((((isnotnull(d_year#74) AND isnotnull(d_moy#75)) AND (d_year#74 = 2000)) AND (d_moy#75 = 2)) AND isnotnull(d_date_sk#39)) +Input [3]: [d_date_sk#39, d_year#73, d_moy#74] +Condition : ((((isnotnull(d_year#73) AND isnotnull(d_moy#74)) AND (d_year#73 = 2000)) AND (d_moy#74 = 2)) 
AND isnotnull(d_date_sk#39)) (133) Project [codegen id : 1] Output [1]: [d_date_sk#39] -Input [3]: [d_date_sk#39, d_year#74, d_moy#75] +Input [3]: [d_date_sk#39, d_year#73, d_moy#74] (134) BroadcastExchange Input [1]: [d_date_sk#39] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#76] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#75] Subquery:2 Hosting operator id = 6 Hosting Expression = ss_sold_date_sk#9 IN dynamicpruning#10 BroadcastExchange (139) @@ -755,26 +755,26 @@ BroadcastExchange (139) (135) Scan parquet default.date_dim -Output [3]: [d_date_sk#11, d_date#12, d_year#77] +Output [3]: [d_date_sk#11, d_date#12, d_year#76] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)] ReadSchema: struct (136) ColumnarToRow [codegen id : 1] -Input [3]: [d_date_sk#11, d_date#12, d_year#77] +Input [3]: [d_date_sk#11, d_date#12, d_year#76] (137) Filter [codegen id : 1] -Input [3]: [d_date_sk#11, d_date#12, d_year#77] -Condition : (d_year#77 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#11)) +Input [3]: [d_date_sk#11, d_date#12, d_year#76] +Condition : (d_year#76 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#11)) (138) Project [codegen id : 1] Output [2]: [d_date_sk#11, d_date#12] -Input [3]: [d_date_sk#11, d_date#12, d_year#77] +Input [3]: [d_date_sk#11, d_date#12, d_year#76] (139) BroadcastExchange Input [2]: [d_date_sk#11, d_date#12] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#78] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#77] Subquery:3 Hosting operator id = 45 Hosting Expression = Subquery scalar-subquery#37, [id=#38] * HashAggregate (156) @@ -797,89 +797,89 @@ Subquery:3 Hosting operator id = 45 Hosting Expression = Subquery scalar-subquer (140) Scan parquet default.store_sales -Output [4]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, ss_sold_date_sk#82] +Output [4]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80, ss_sold_date_sk#81] Batched: true Location: InMemoryFileIndex [] -PartitionFilters: [isnotnull(ss_sold_date_sk#82), dynamicpruningexpression(ss_sold_date_sk#82 IN dynamicpruning#83)] +PartitionFilters: [isnotnull(ss_sold_date_sk#81), dynamicpruningexpression(ss_sold_date_sk#81 IN dynamicpruning#82)] PushedFilters: [IsNotNull(ss_customer_sk)] ReadSchema: struct (141) ColumnarToRow [codegen id : 2] -Input [4]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, ss_sold_date_sk#82] +Input [4]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80, ss_sold_date_sk#81] (142) Filter [codegen id : 2] -Input [4]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, ss_sold_date_sk#82] -Condition : isnotnull(ss_customer_sk#79) +Input [4]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80, ss_sold_date_sk#81] +Condition : isnotnull(ss_customer_sk#78) (143) ReusedExchange [Reuses operator id: 161] -Output [1]: [d_date_sk#84] +Output [1]: [d_date_sk#83] (144) BroadcastHashJoin [codegen id : 2] -Left keys [1]: [ss_sold_date_sk#82] -Right keys [1]: [d_date_sk#84] +Left keys [1]: [ss_sold_date_sk#81] +Right keys [1]: [d_date_sk#83] Join condition: None (145) Project [codegen id : 2] -Output [3]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81] -Input [5]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, ss_sold_date_sk#82, d_date_sk#84] +Output [3]: 
[ss_customer_sk#78, ss_quantity#79, ss_sales_price#80] +Input [5]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80, ss_sold_date_sk#81, d_date_sk#83] (146) Exchange -Input [3]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81] -Arguments: hashpartitioning(ss_customer_sk#79, 5), ENSURE_REQUIREMENTS, [id=#85] +Input [3]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80] +Arguments: hashpartitioning(ss_customer_sk#78, 5), ENSURE_REQUIREMENTS, [id=#84] (147) Sort [codegen id : 3] -Input [3]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81] -Arguments: [ss_customer_sk#79 ASC NULLS FIRST], false, 0 +Input [3]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80] +Arguments: [ss_customer_sk#78 ASC NULLS FIRST], false, 0 (148) ReusedExchange [Reuses operator id: 39] -Output [1]: [c_customer_sk#86] +Output [1]: [c_customer_sk#85] (149) Sort [codegen id : 5] -Input [1]: [c_customer_sk#86] -Arguments: [c_customer_sk#86 ASC NULLS FIRST], false, 0 +Input [1]: [c_customer_sk#85] +Arguments: [c_customer_sk#85 ASC NULLS FIRST], false, 0 (150) SortMergeJoin [codegen id : 6] -Left keys [1]: [ss_customer_sk#79] -Right keys [1]: [c_customer_sk#86] +Left keys [1]: [ss_customer_sk#78] +Right keys [1]: [c_customer_sk#85] Join condition: None (151) Project [codegen id : 6] -Output [3]: [ss_quantity#80, ss_sales_price#81, c_customer_sk#86] -Input [4]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, c_customer_sk#86] +Output [3]: [ss_quantity#79, ss_sales_price#80, c_customer_sk#85] +Input [4]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80, c_customer_sk#85] (152) HashAggregate [codegen id : 6] -Input [3]: [ss_quantity#80, ss_sales_price#81, c_customer_sk#86] -Keys [1]: [c_customer_sk#86] -Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#80 as decimal(12,2))) * promote_precision(cast(ss_sales_price#81 as decimal(12,2)))), DecimalType(18,2)))] -Aggregate Attributes [2]: [sum#87, isEmpty#88] -Results [3]: [c_customer_sk#86, sum#89, isEmpty#90] +Input [3]: [ss_quantity#79, ss_sales_price#80, c_customer_sk#85] +Keys [1]: [c_customer_sk#85] +Functions [1]: [partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#79 as decimal(12,2))) * promote_precision(cast(ss_sales_price#80 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [2]: [sum#86, isEmpty#87] +Results [3]: [c_customer_sk#85, sum#88, isEmpty#89] (153) HashAggregate [codegen id : 6] -Input [3]: [c_customer_sk#86, sum#89, isEmpty#90] -Keys [1]: [c_customer_sk#86] -Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#80 as decimal(12,2))) * promote_precision(cast(ss_sales_price#81 as decimal(12,2)))), DecimalType(18,2)))] -Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#80 as decimal(12,2))) * promote_precision(cast(ss_sales_price#81 as decimal(12,2)))), DecimalType(18,2)))#91] -Results [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#80 as decimal(12,2))) * promote_precision(cast(ss_sales_price#81 as decimal(12,2)))), DecimalType(18,2)))#91 AS csales#92] +Input [3]: [c_customer_sk#85, sum#88, isEmpty#89] +Keys [1]: [c_customer_sk#85] +Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#79 as decimal(12,2))) * promote_precision(cast(ss_sales_price#80 as decimal(12,2)))), DecimalType(18,2)))] +Aggregate Attributes [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#79 as decimal(12,2))) * promote_precision(cast(ss_sales_price#80 as decimal(12,2)))), DecimalType(18,2)))#90] +Results [1]: 
[sum(CheckOverflow((promote_precision(cast(ss_quantity#79 as decimal(12,2))) * promote_precision(cast(ss_sales_price#80 as decimal(12,2)))), DecimalType(18,2)))#90 AS csales#91] (154) HashAggregate [codegen id : 6] -Input [1]: [csales#92] +Input [1]: [csales#91] Keys: [] -Functions [1]: [partial_max(csales#92)] -Aggregate Attributes [1]: [max#93] -Results [1]: [max#94] +Functions [1]: [partial_max(csales#91)] +Aggregate Attributes [1]: [max#92] +Results [1]: [max#93] (155) Exchange -Input [1]: [max#94] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#95] +Input [1]: [max#93] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#94] (156) HashAggregate [codegen id : 7] -Input [1]: [max#94] +Input [1]: [max#93] Keys: [] -Functions [1]: [max(csales#92)] -Aggregate Attributes [1]: [max(csales#92)#96] -Results [1]: [max(csales#92)#96 AS tpcds_cmax#97] +Functions [1]: [max(csales#91)] +Aggregate Attributes [1]: [max(csales#91)#95] +Results [1]: [max(csales#91)#95 AS tpcds_cmax#96] -Subquery:4 Hosting operator id = 140 Hosting Expression = ss_sold_date_sk#82 IN dynamicpruning#83 +Subquery:4 Hosting operator id = 140 Hosting Expression = ss_sold_date_sk#81 IN dynamicpruning#82 BroadcastExchange (161) +- * Project (160) +- * Filter (159) @@ -888,26 +888,26 @@ BroadcastExchange (161) (157) Scan parquet default.date_dim -Output [2]: [d_date_sk#84, d_year#98] +Output [2]: [d_date_sk#83, d_year#97] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)] ReadSchema: struct (158) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#84, d_year#98] +Input [2]: [d_date_sk#83, d_year#97] (159) Filter [codegen id : 1] -Input [2]: [d_date_sk#84, d_year#98] -Condition : (d_year#98 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#84)) +Input [2]: [d_date_sk#83, d_year#97] +Condition : (d_year#97 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#83)) (160) Project [codegen id : 1] -Output [1]: [d_date_sk#84] -Input [2]: [d_date_sk#84, d_year#98] +Output [1]: [d_date_sk#83] +Input [2]: [d_date_sk#83, d_year#97] (161) BroadcastExchange -Input [1]: [d_date_sk#84] -Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#99] +Input [1]: [d_date_sk#83] +Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#98] Subquery:5 Hosting operator id = 65 Hosting Expression = ReusedSubquery Subquery scalar-subquery#37, [id=#38] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala index dd30ff68da417..7d45102ac83d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.logical.WithCTE +import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, LessThan, Literal, Or} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project, RepartitionOperation, WithCTE} import org.apache.spark.sql.execution.adaptive._ import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.internal.SQLConf @@ -42,7 +43,7 @@ abstract class CTEInlineSuiteBase """.stripMargin) checkAnswer(df, Nil) assert( - df.queryExecution.optimizedPlan.exists(_.isInstanceOf[WithCTE]), + df.queryExecution.optimizedPlan.exists(_.isInstanceOf[RepartitionOperation]), 
"Non-deterministic With-CTE with multiple references should be not inlined.") } } @@ -59,7 +60,7 @@ abstract class CTEInlineSuiteBase """.stripMargin) checkAnswer(df, Nil) assert( - df.queryExecution.optimizedPlan.exists(_.isInstanceOf[WithCTE]), + df.queryExecution.optimizedPlan.exists(_.isInstanceOf[RepartitionOperation]), "Non-deterministic With-CTE with multiple references should be not inlined.") } } @@ -79,7 +80,7 @@ abstract class CTEInlineSuiteBase df.queryExecution.analyzed.exists(_.isInstanceOf[WithCTE]), "With-CTE should not be inlined in analyzed plan.") assert( - !df.queryExecution.optimizedPlan.exists(_.isInstanceOf[WithCTE]), + !df.queryExecution.optimizedPlan.exists(_.isInstanceOf[RepartitionOperation]), "With-CTE with one reference should be inlined in optimized plan.") } } @@ -107,8 +108,8 @@ abstract class CTEInlineSuiteBase "With-CTE should contain 2 CTE defs after analysis.") assert( df.queryExecution.optimizedPlan.collect { - case WithCTE(_, cteDefs) => cteDefs - }.head.length == 2, + case r: RepartitionOperation => r + }.length == 6, "With-CTE should contain 2 CTE def after optimization.") } } @@ -136,8 +137,8 @@ abstract class CTEInlineSuiteBase "With-CTE should contain 2 CTE defs after analysis.") assert( df.queryExecution.optimizedPlan.collect { - case WithCTE(_, cteDefs) => cteDefs - }.head.length == 1, + case r: RepartitionOperation => r + }.length == 4, "One CTE def should be inlined after optimization.") } } @@ -163,7 +164,7 @@ abstract class CTEInlineSuiteBase "With-CTE should contain 2 CTE defs after analysis.") assert( df.queryExecution.optimizedPlan.collect { - case WithCTE(_, cteDefs) => cteDefs + case r: RepartitionOperation => r }.isEmpty, "CTEs with one reference should all be inlined after optimization.") } @@ -248,7 +249,7 @@ abstract class CTEInlineSuiteBase "With-CTE should contain 2 CTE defs after analysis.") assert( df.queryExecution.optimizedPlan.collect { - case WithCTE(_, cteDefs) => cteDefs + case r: RepartitionOperation => r }.isEmpty, "Deterministic CTEs should all be inlined after optimization.") } @@ -272,6 +273,214 @@ abstract class CTEInlineSuiteBase assert(ex.message.contains("Table or view not found: v1")) } } + + test("CTE Predicate push-down and column pruning") { + withView("t") { + Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") + val df = sql( + s"""with + |v as ( + | select c1, c2, 's' c3, rand() c4 from t + |), + |vv as ( + | select v1.c1, v1.c2, rand() c5 from v v1, v v2 + | where v1.c1 > 0 and v1.c3 = 's' and v1.c2 = v2.c2 + |) + |select vv1.c1, vv1.c2, vv2.c1, vv2.c2 from vv vv1, vv vv2 + |where vv1.c2 > 0 and vv2.c2 > 0 and vv1.c1 = vv2.c1 + """.stripMargin) + checkAnswer(df, Row(1, 2, 1, 2) :: Nil) + assert( + df.queryExecution.analyzed.collect { + case WithCTE(_, cteDefs) => cteDefs + }.head.length == 2, + "With-CTE should contain 2 CTE defs after analysis.") + val cteRepartitions = df.queryExecution.optimizedPlan.collect { + case r: RepartitionOperation => r + } + assert(cteRepartitions.length == 6, + "CTE should not be inlined after optimization.") + val distinctCteRepartitions = cteRepartitions.map(_.canonicalized).distinct + // Check column pruning and predicate push-down. 
+ assert(distinctCteRepartitions.length == 2) + assert(distinctCteRepartitions(1).collectFirst { + case p: Project if p.projectList.length == 3 => p + }.isDefined, "CTE columns should be pruned.") + assert(distinctCteRepartitions(1).collectFirst { + case f: Filter if f.condition.semanticEquals(GreaterThan(f.output(1), Literal(0))) => f + }.isDefined, "Predicate 'c2 > 0' should be pushed down to the CTE def 'v'.") + assert(distinctCteRepartitions(0).collectFirst { + case f: Filter if f.condition.find(_.semanticEquals(f.output(0))).isDefined => f + }.isDefined, "CTE 'vv' definition contains predicate 'c1 > 0'.") + assert(distinctCteRepartitions(1).collectFirst { + case f: Filter if f.condition.find(_.semanticEquals(f.output(0))).isDefined => f + }.isEmpty, "Predicate 'c1 > 0' should be not pushed down to the CTE def 'v'.") + // Check runtime repartition reuse. + assert( + collectWithSubqueries(df.queryExecution.executedPlan) { + case r: ReusedExchangeExec => r + }.length == 2, + "CTE repartition is reused.") + } + } + + test("CTE Predicate push-down and column pruning - combined predicate") { + withView("t") { + Seq((0, 1, 2), (1, 2, 3)).toDF("c1", "c2", "c3").createOrReplaceTempView("t") + val df = sql( + s"""with + |v as ( + | select c1, c2, c3, rand() c4 from t + |), + |vv as ( + | select v1.c1, v1.c2, rand() c5 from v v1, v v2 + | where v1.c1 > 0 and v2.c3 < 5 and v1.c2 = v2.c2 + |) + |select vv1.c1, vv1.c2, vv2.c1, vv2.c2 from vv vv1, vv vv2 + |where vv1.c2 > 0 and vv2.c2 > 0 and vv1.c1 = vv2.c1 + """.stripMargin) + checkAnswer(df, Row(1, 2, 1, 2) :: Nil) + assert( + df.queryExecution.analyzed.collect { + case WithCTE(_, cteDefs) => cteDefs + }.head.length == 2, + "With-CTE should contain 2 CTE defs after analysis.") + val cteRepartitions = df.queryExecution.optimizedPlan.collect { + case r: RepartitionOperation => r + } + assert(cteRepartitions.length == 6, + "CTE should not be inlined after optimization.") + val distinctCteRepartitions = cteRepartitions.map(_.canonicalized).distinct + // Check column pruning and predicate push-down. + assert(distinctCteRepartitions.length == 2) + assert(distinctCteRepartitions(1).collectFirst { + case p: Project if p.projectList.length == 3 => p + }.isDefined, "CTE columns should be pruned.") + assert( + distinctCteRepartitions(1).collectFirst { + case f: Filter + if f.condition.semanticEquals( + And( + GreaterThan(f.output(1), Literal(0)), + Or( + GreaterThan(f.output(0), Literal(0)), + LessThan(f.output(2), Literal(5))))) => + f + }.isDefined, + "Predicate 'c2 > 0 AND (c1 > 0 OR c3 < 5)' should be pushed down to the CTE def 'v'.") + // Check runtime repartition reuse. 
+ assert( + collectWithSubqueries(df.queryExecution.executedPlan) { + case r: ReusedExchangeExec => r + }.length == 2, + "CTE repartition is reused.") + } + } + + test("Views with CTEs - 1 temp view") { + withView("t", "t2") { + Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") + sql( + s"""with + |v as ( + | select c1 + c2 c3 from t + |) + |select sum(c3) s from v + """.stripMargin).createOrReplaceTempView("t2") + val df = sql( + s"""with + |v as ( + | select c1 * c2 c3 from t + |) + |select sum(c3) from v except select s from t2 + """.stripMargin) + checkAnswer(df, Row(2) :: Nil) + } + } + + test("Views with CTEs - 2 temp views") { + withView("t", "t2", "t3") { + Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") + sql( + s"""with + |v as ( + | select c1 + c2 c3 from t + |) + |select sum(c3) s from v + """.stripMargin).createOrReplaceTempView("t2") + sql( + s"""with + |v as ( + | select c1 * c2 c3 from t + |) + |select sum(c3) s from v + """.stripMargin).createOrReplaceTempView("t3") + val df = sql("select s from t3 except select s from t2") + checkAnswer(df, Row(2) :: Nil) + } + } + + test("Views with CTEs - temp view + sql view") { + withTable("t") { + withView ("t2", "t3") { + Seq((0, 1), (1, 2)).toDF("c1", "c2").write.saveAsTable("t") + sql( + s"""with + |v as ( + | select c1 + c2 c3 from t + |) + |select sum(c3) s from v + """.stripMargin).createOrReplaceTempView("t2") + sql( + s"""create view t3 as + |with + |v as ( + | select c1 * c2 c3 from t + |) + |select sum(c3) s from v + """.stripMargin) + val df = sql("select s from t3 except select s from t2") + checkAnswer(df, Row(2) :: Nil) + } + } + } + + test("Union of Dataframes with CTEs") { + val a = spark.sql("with t as (select 1 as n) select * from t ") + val b = spark.sql("with t as (select 2 as n) select * from t ") + val df = a.union(b) + checkAnswer(df, Row(1) :: Row(2) :: Nil) + } + + test("CTE definitions out of original order when not inlined") { + withView("t1", "t2") { + Seq((1, 2, 10, 100), (2, 3, 20, 200)).toDF("workspace_id", "issue_id", "shard_id", "field_id") + .createOrReplaceTempView("issue_current") + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.InlineCTE") { + val df = sql( + """ + |WITH cte_0 AS ( + | SELECT workspace_id, issue_id, shard_id, field_id FROM issue_current + |), + |cte_1 AS ( + | WITH filtered_source_table AS ( + | SELECT * FROM cte_0 WHERE shard_id in ( 10 ) + | ) + | SELECT source_table.workspace_id, field_id FROM cte_0 source_table + | INNER JOIN ( + | SELECT workspace_id, issue_id FROM filtered_source_table GROUP BY 1, 2 + | ) target_table + | ON source_table.issue_id = target_table.issue_id + | AND source_table.workspace_id = target_table.workspace_id + | WHERE source_table.shard_id IN ( 10 ) + |) + |SELECT * FROM cte_1 + """.stripMargin) + checkAnswer(df, Row(1, 100) :: Nil) + } + } + } } class CTEInlineSuiteAEOff extends CTEInlineSuiteBase with DisableAdaptiveExecutionSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0b00659f73b81..70b38db034f65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -593,6 +593,21 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |select * from q1 union all select * from q2""".stripMargin), Row(5, "5") :: Row(4, "4") :: Nil) + // inner CTE 
relation refers to outer CTE relation. + withSQLConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY.key -> "CORRECTED") { + checkAnswer( + sql( + """ + |with temp1 as (select 1 col), + |temp2 as ( + | with temp1 as (select col + 1 AS col from temp1), + | temp3 as (select col + 1 from temp1) + | select * from temp3 + |) + |select * from temp2 + |""".stripMargin), + Row(3)) + } } test("Allow only a single WITH clause per query") {
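
For readers who want to exercise the nested-CTE precedence behavior covered by the new SQLQuerySuite case outside the test harness, a minimal standalone sketch follows. It is illustrative only and not part of this patch; the local master, app name, and wrapper object are assumptions added for the example, while the config key and the expected result (3 under the CORRECTED policy) come from the test above.

// Minimal sketch, assuming a local Spark installation on the classpath.
import org.apache.spark.sql.SparkSession

object NestedCtePrecedenceExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("nested-cte-precedence")
      // Same policy the new test sets via SQLConf.LEGACY_CTE_PRECEDENCE_POLICY.
      .config("spark.sql.legacy.ctePrecedencePolicy", "CORRECTED")
      .getOrCreate()

    // Under the CORRECTED policy the inner `temp1` shadows the outer one, so
    // `temp3` reads the inner definition and the query returns 3, matching the
    // Row(3) assertion in the test above.
    spark.sql(
      """
        |with temp1 as (select 1 col),
        |temp2 as (
        |  with temp1 as (select col + 1 as col from temp1),
        |  temp3 as (select col + 1 from temp1)
        |  select * from temp3
        |)
        |select * from temp2
        |""".stripMargin).show()

    spark.stop()
  }
}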