From d57ecc19ec3ac1fbf79513519cd4d6e5781aca45 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 16 May 2019 09:17:48 +0800
Subject: [PATCH 1/5] add a logical plan link in the physical plan

---
 .../spark/sql/catalyst/plans/QueryPlan.scala  |   4 +-
 .../plans/logical/basicLogicalOperators.scala |   6 +-
 .../spark/sql/catalyst/trees/TreeNode.scala   |  33 ++++-
 .../sql/catalyst/trees/TreeNodeSuite.scala    |   4 +
 .../spark/sql/execution/SparkStrategies.scala |  11 ++
 .../LogicalPlanTagInSparkPlanSuite.scala      | 128 ++++++++++++++++++
 6 files changed, 180 insertions(+), 6 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index a6968c117782..ee704fac281c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode, TreeNodeTagName}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
 
@@ -271,6 +271,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
 }
 
 object QueryPlan extends PredicateHelper {
+  val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
+
   /**
    * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
    * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 2b98132f188f..95368214d61f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1078,7 +1078,11 @@ case class OneRowRelation() extends LeafNode {
   override def computeStats(): Statistics = Statistics(sizeInBytes = 1)
 
   /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
-  override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = OneRowRelation()
+  override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
+    val newCopy = OneRowRelation()
+    newCopy.tags ++= this.tags
+    newCopy
+  }
 }
 
 /** A logical plan for `dropDuplicates`. */
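Note: `OneRowRelation` needs this hand-written override because the reflective `makeCopy` path in TreeNode (patched below to copy tags) cannot invoke a 0-arg constructor. Any node with the same constraint would follow the same pattern; a sketch, using a made-up `MyZeroArgRelation` (`LeafNode` and `Attribute` are in scope in basicLogicalOperators.scala):

    case class MyZeroArgRelation() extends LeafNode {
      override def output: Seq[Attribute] = Nil

      // The reflective TreeNode.makeCopy cannot call a 0-arg constructor, so
      // construct directly and copy the tags by hand, as OneRowRelation does above.
      override def makeCopy(newArgs: Array[AnyRef]): MyZeroArgRelation = {
        val newCopy = MyZeroArgRelation()
        newCopy.tags ++= this.tags
        newCopy
      }
    }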
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 84ca0666e4cb..4f3709334425 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.trees
 
 import java.util.UUID
 
-import scala.collection.Map
+import scala.collection.{mutable, Map}
 import scala.reflect.ClassTag
 
 import org.apache.commons.lang3.ClassUtils
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
-import org.apache.spark.sql.catalyst.util.StringUtils.{PlanStringConcat, StringConcat}
+import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -74,6 +74,10 @@ object CurrentOrigin {
   }
 }
 
+// The name of the tree node tag. This is preferred over using string directly, as we can easily
+// find all the defined tags.
+case class TreeNodeTagName(name: String)
+
 // scalastyle:off
 abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 // scalastyle:on
@@ -81,6 +85,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
   val origin: Origin = CurrentOrigin.get
 
+  /**
+   * A mutable map for holding auxiliary information of this tree node. It will be carried over
+   * when this node is copied via `makeCopy`. If a user copies the tree node via other ways like the
+   * `copy` method, it's his responsibility to carry over the tags.
+   */
+  val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
+
   /**
    * Returns a Seq of the children of this node.
    * Children should not change. Immutability required for containsChild optimization
@@ -262,6 +273,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     if (this fastEquals afterRule) {
       mapChildren(_.transformDown(rule))
     } else {
+      // If the transform function replaces this node with a new one of the same type, carry over
+      // the tags.
+      if (afterRule.getClass == this.getClass) {
+        afterRule.tags ++= this.tags
+      }
+
       afterRule.mapChildren(_.transformDown(rule))
     }
   }
@@ -280,9 +297,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
         rule.applyOrElse(this, identity[BaseType])
       }
     } else {
-      CurrentOrigin.withOrigin(origin) {
+      val newNode = CurrentOrigin.withOrigin(origin) {
         rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
       }
+      // If the transform function replaces this node with a new one of the same type, carry over
+      // the tags.
+      if (newNode.getClass == this.getClass) {
+        newNode.tags ++= this.tags
+      }
+      newNode
     }
   }
@@ -402,7 +425,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
     try {
       CurrentOrigin.withOrigin(origin) {
-        defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
+        val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
+        res.tags ++= this.tags
+        res
       }
     } catch {
       case e: java.lang.IllegalArgumentException =>
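Note: the `tags` map above is the heart of this patch: a mutable, per-node side channel that `makeCopy`, `transformDown` and `transformUp` now propagate. A minimal usage sketch against this API (the tag name "my_marker" is invented for illustration):

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.trees.TreeNodeTagName

    val MY_MARKER = TreeNodeTagName("my_marker")

    val lit = Literal(1)
    lit.tags += MY_MARKER -> "seen"                       // attach auxiliary info
    val seen = lit.tags.get(MY_MARKER).contains("seen")   // and read it back

Defining tag names as `TreeNodeTagName` constants rather than raw strings is what makes every tag discoverable with a single usage search, per the comment above the case class.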
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index e7ad04f4af78..bd4f7b8bce37 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -620,4 +620,8 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
       assert(planString.startsWith("Truncated plan of"))
     }
   }
+
+  test("tags will be carried over after copy") {
+
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 831fc7363486..29c04a2f81bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -63,6 +63,17 @@ case class PlanLater(plan: LogicalPlan) extends LeafExecNode {
 abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SparkPlanner =>
 
+  override def plan(plan: LogicalPlan): Iterator[SparkPlan] = {
+    super.plan(plan).map { p =>
+      val logicalPlan = plan match {
+        case ReturnAnswer(rootPlan) => rootPlan
+        case _ => plan
+      }
+      p.tags += QueryPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
+      p
+    }
+  }
+
   /**
    * Plans special cases of limit operators.
    */
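Note: this override is where the logical-to-physical link is created: every plan the planner returns is tagged with the logical plan it was derived from, after unwrapping `ReturnAnswer`. Reading the link back mirrors what the new test suite below does; a sketch (patch 5 later moves the constant to the `SparkPlan` companion object):

    import org.apache.spark.sql.catalyst.plans.QueryPlan
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    def logicalLink(plan: SparkPlan): Option[LogicalPlan] =
      plan.tags.get(QueryPlan.LOGICAL_PLAN_TAG_NAME).map(_.asInstanceOf[LogicalPlan])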
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
new file mode 100644
index 000000000000..306d9e9b5b37
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.TPCDSQuerySuite
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.execution.window.WindowExec
+
+class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
+
+  override protected def checkGeneratedCode(plan: SparkPlan): Unit = {
+    super.checkGeneratedCode(plan)
+    checkLogicalPlanTag(plan)
+  }
+
+  private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = {
+    // TODO: aggregate node without aggregate expressions can also be a final aggregate, but
+    // currently the aggregate node doesn't have a final/partial flag.
+    aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final)
+  }
+
+  // A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes.
+  private def isScanPlanTree(plan: SparkPlan): Boolean = plan match {
+    case p: ProjectExec => isScanPlanTree(p.child)
+    case f: FilterExec => isScanPlanTree(f.child)
+    case _: LeafExecNode => true
+    case _ => false
+  }
+
+  private def checkLogicalPlanTag(plan: SparkPlan): Unit = {
+    plan match {
+      case _: HashJoin | _: BroadcastNestedLoopJoinExec | _: CartesianProductExec
+           | _: ShuffledHashJoinExec | _: SortMergeJoinExec =>
+        assertLogicalPlanType[Join](plan)
+
+      // There is no corresponding logical plan for the physical partial aggregate.
+      case agg: HashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
+        assertLogicalPlanType[Aggregate](plan)
+      case agg: ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
+        assertLogicalPlanType[Aggregate](plan)
+      case agg: SortAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
+        assertLogicalPlanType[Aggregate](plan)
+
+      case _: WindowExec =>
+        assertLogicalPlanType[Window](plan)
+
+      case _: UnionExec =>
+        assertLogicalPlanType[Union](plan)
+
+      case _: SampleExec =>
+        assertLogicalPlanType[Sample](plan)
+
+      case _: GenerateExec =>
+        assertLogicalPlanType[Generate](plan)
+
+      // The exchange related nodes are created after the planning, they don't have corresponding
+      // logical plan.
+      case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
+        assert(!plan.tags.contains(QueryPlan.LOGICAL_PLAN_TAG_NAME))
+
+      case _ if isScanPlanTree(plan) =>
+        // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
+        // so it's not simple to check. Instead, we only check that the origin LogicalPlan
+        // contains the corresponding leaf node of the SparkPlan.
+        // A strategy might remove the filter if it's totally pushed down, e.g.:
+        //   logical = Project(Filter(Scan A))
+        //   physical = ProjectExec(ScanExec A)
+        // We only check that leaf nodes match between logical and physical plan.
+        val logicalLeaves = getLogicalPlan(plan).collectLeaves()
+        val physicalLeaves = plan.collectLeaves()
+        assert(logicalLeaves.length == 1)
+        assert(physicalLeaves.length == 1)
+        physicalLeaves.head match {
+          case _: RangeExec => assert(logicalLeaves.head.isInstanceOf[Range])
+          case _: DataSourceScanExec => assert(logicalLeaves.head.isInstanceOf[LogicalRelation])
+          case _: InMemoryTableScanExec => assert(logicalLeaves.head.isInstanceOf[InMemoryRelation])
+          case _: LocalTableScanExec => assert(logicalLeaves.head.isInstanceOf[LocalRelation])
+          case _: ExternalRDDScanExec[_] => assert(logicalLeaves.head.isInstanceOf[ExternalRDD[_]])
+          case _: BatchScanExec => assert(logicalLeaves.head.isInstanceOf[DataSourceV2Relation])
+          case _ =>
+        }
+        // Do not need to check the children recursively.
+        return
+
+      case _ =>
+    }
+
+    plan.children.foreach(checkLogicalPlanTag)
+    plan.subqueries.foreach(checkLogicalPlanTag)
+  }
+
+  private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
+    assert(node.tags.contains(QueryPlan.LOGICAL_PLAN_TAG_NAME),
+      node.getClass.getSimpleName + " does not have a logical plan link")
+    node.tags(QueryPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
+  }
+
+  private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {
+    val logicalPlan = getLogicalPlan(node)
+    val expectedCls = implicitly[ClassTag[T]].runtimeClass
+    assert(expectedCls == logicalPlan.getClass)
+  }
+}
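Note: the suite above piggybacks on the TPC-DS codegen checks, so the link is validated on realistic plans. The effect can also be seen on any ad-hoc query; a sketch, assuming a live `SparkSession` named `spark` (nodes created after planning, such as exchanges and whole-stage-codegen wrappers, print `None`):

    import org.apache.spark.sql.catalyst.plans.QueryPlan

    val df = spark.range(10).where("id > 5").selectExpr("id * 2 AS doubled")
    df.queryExecution.executedPlan.foreach { node =>
      val link = node.tags.get(QueryPlan.LOGICAL_PLAN_TAG_NAME)
      println(s"${node.nodeName} -> ${link.map(_.getClass.getSimpleName)}")
    }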
From 9f377d7a30bbfc6d982321212a585451cc87bb17 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 17 May 2019 14:04:36 +0800
Subject: [PATCH 2/5] address comments

---
 .../spark/sql/catalyst/plans/QueryPlan.scala  |  2 +
 .../spark/sql/catalyst/trees/TreeNode.scala   | 20 ++++----
 .../sql/catalyst/trees/TreeNodeSuite.scala    | 47 +++++++++++++++++++
 .../LogicalPlanTagInSparkPlanSuite.scala      |  5 ++
 4 files changed, 64 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index ee704fac281c..f12a665541a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -271,6 +271,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
 }
 
 object QueryPlan extends PredicateHelper {
+  // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag
+  // when converting a logical plan to a physical plan.
   val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 4f3709334425..f1698d198b62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -87,8 +87,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
   /**
    * A mutable map for holding auxiliary information of this tree node. It will be carried over
-   * when this node is copied via `makeCopy`. If a user copies the tree node via other ways like the
-   * `copy` method, it's his responsibility to carry over the tags.
+   * when this node is copied via `makeCopy`. The tags will be kept after transforming, if
+   * the node is transformed to the same type. Otherwise, tags will be dropped.
    */
   val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
 
@@ -292,21 +292,21 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    */
   def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
     val afterRuleOnChildren = mapChildren(_.transformUp(rule))
-    if (this fastEquals afterRuleOnChildren) {
+    val newNode = if (this fastEquals afterRuleOnChildren) {
       CurrentOrigin.withOrigin(origin) {
         rule.applyOrElse(this, identity[BaseType])
       }
     } else {
-      val newNode = CurrentOrigin.withOrigin(origin) {
+      CurrentOrigin.withOrigin(origin) {
         rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
       }
-      // If the transform function replaces this node with a new one of the same type, carry over
-      // the tags.
-      if (newNode.getClass == this.getClass) {
-        newNode.tags ++= this.tags
-      }
-      newNode
     }
+    // If the transform function replaces this node with a new one of the same type, carry over
+    // the tags.
+    if (newNode.getClass == this.getClass) {
+      newNode.tags ++= this.tags
+    }
+    newNode
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index bd4f7b8bce37..530c78a8c34f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -622,6 +622,53 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }
 
   test("tags will be carried over after copy") {
+    withClue("makeCopy") {
+      val node = Dummy(None)
+      node.tags += TreeNodeTagName("test") -> "a"
+      val copied = node.makeCopy(Array(Some(Literal(1))))
+      assert(copied.tags(TreeNodeTagName("test")) == "a")
+    }
+
+    def checkTransform(
+        sameTypeTransform: Expression => Expression,
+        differentTypeTransform: Expression => Expression): Unit = {
+      val child = Dummy(None)
+      child.tags += TreeNodeTagName("test") -> "child"
+      val node = Dummy(Some(child))
+      node.tags += TreeNodeTagName("test") -> "parent"
+
+      val transformed = sameTypeTransform(node)
+      // Both the child and parent keep the tags
+      assert(transformed.tags(TreeNodeTagName("test")) == "parent")
+      assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child")
+
+      val transformed2 = differentTypeTransform(node)
+      // The parent keeps the tag, but the child loses the tag because it's transformed to a
+      // different type of node.
+      assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
+      assert(!transformed2.children.head.tags.contains(TreeNodeTagName("test")))
+    }
+
+    withClue("transformDown") {
+      checkTransform(
+        sameTypeTransform = _ transformDown {
+          case Dummy(None) => Dummy(Some(Literal(1)))
+        },
+        differentTypeTransform = _ transformDown {
+          case Dummy(None) => Literal(1)
+        })
+    }
+
+    withClue("transformUp") {
+      checkTransform(
+        sameTypeTransform = _ transformUp {
+          case Dummy(None) => Dummy(Some(Literal(1)))
+        },
+        differentTypeTransform = _ transformUp {
+          case Dummy(None) => Literal(1)
+
+        })
+    }
 
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
index 306d9e9b5b37..10ca5ce5e4ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
@@ -83,6 +83,11 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
       case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
         assert(!plan.tags.contains(QueryPlan.LOGICAL_PLAN_TAG_NAME))
 
+      // The subquery exec nodes are just wrappers of the actual nodes, they don't have
+      // corresponding logical plan.
+      case _: SubqueryExec | _: ReusedSubqueryExec =>
+        assert(!plan.tags.contains(QueryPlan.LOGICAL_PLAN_TAG_NAME))
+
       case _ if isScanPlanTree(plan) =>
         // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
         // so it's not simple to check. Instead, we only check that the origin LogicalPlan
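Note: at this point in the series a transform keeps the tags only when the rule returns a node of the same class. The behaviour can be seen with plain catalyst expressions; a sketch (the tag name "note" is invented for illustration):

    import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
    import org.apache.spark.sql.catalyst.trees.TreeNodeTagName

    val TAG = TreeNodeTagName("note")
    val one = Literal(1)
    one.tags += TAG -> "kept"

    // Same type (Literal -> Literal): the tag is carried over.
    val two = one.transformUp { case Literal(1, _) => Literal(2) }
    assert(two.tags(TAG) == "kept")

    // Different type (Literal -> Alias): the tag is dropped here. Patch 3 below
    // removes the same-type restriction, so it would then be carried over too.
    val aliased = one.transformUp { case l @ Literal(1, _) => Alias(l, "one")() }
    assert(!aliased.tags.contains(TAG))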
From dafad5ca546e819462ca93b1d448e342478a28ec Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 18 May 2019 00:24:40 +0800
Subject: [PATCH 3/5] always carry over the tags in transform

---
 .../spark/sql/catalyst/trees/TreeNode.scala   | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index f1698d198b62..a5705d0f3250 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -87,8 +87,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
   /**
    * A mutable map for holding auxiliary information of this tree node. It will be carried over
-   * when this node is copied via `makeCopy`. The tags will be kept after transforming, if
-   * the node is transformed to the same type. Otherwise, tags will be dropped.
+   * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
    */
   val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
 
@@ -273,12 +272,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     if (this fastEquals afterRule) {
       mapChildren(_.transformDown(rule))
     } else {
-      // If the transform function replaces this node with a new one of the same type, carry over
-      // the tags.
-      if (afterRule.getClass == this.getClass) {
-        afterRule.tags ++= this.tags
-      }
-
+      // If the transform function replaces this node with a new one, carry over the tags.
+      afterRule.tags ++= this.tags
       afterRule.mapChildren(_.transformDown(rule))
     }
   }
@@ -301,11 +296,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
         rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
       }
     }
-    // If the transform function replaces this node with a new one of the same type, carry over
-    // the tags.
-    if (newNode.getClass == this.getClass) {
-      newNode.tags ++= this.tags
-    }
+    // If the transform function replaces this node with a new one, carry over the tags.
+    newNode.tags ++= this.tags
     newNode
   }
 
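Note: with the same-type restriction gone, the replacement node inherits the tags whatever its class. Re-running the different-type case from the previous sketch (same definitions) now keeps the tag, which is exactly what forces patch 4 below to flip the child assertion in TreeNodeSuite:

    val aliased = one.transformUp { case l @ Literal(1, _) => Alias(l, "one")() }
    assert(aliased.tags(TAG) == "kept")   // the tag now follows the new node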
From b033f55560c7f2fc26b898a21cad444d619fab92 Mon Sep 17 00:00:00 2001
From: Peng Bo
Date: Mon, 20 May 2019 13:33:02 +0800
Subject: [PATCH 4/5] fix ut pb (#13)

---
 .../org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 530c78a8c34f..195314b50de5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -621,7 +621,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
     }
   }
 
-  test("tags will be carried over after copy") {
+  test("tags will be carried over after copy & transform") {
     withClue("makeCopy") {
       val node = Dummy(None)
       node.tags += TreeNodeTagName("test") -> "a"
@@ -646,7 +646,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
       // The parent keeps the tag, but the child loses the tag because it's transformed to a
      // different type of node.
       assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
-      assert(!transformed2.children.head.tags.contains(TreeNodeTagName("test")))
+      assert(transformed2.children.head.tags.contains(TreeNodeTagName("test")))
     }
 
     withClue("transformDown") {
From b380f1dc39f0e4c7701b1a46603b0b62d1f94f96 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 20 May 2019 13:38:24 +0800
Subject: [PATCH 5/5] update

---
 .../org/apache/spark/sql/catalyst/plans/QueryPlan.scala   | 6 +-----
 .../apache/spark/sql/catalyst/trees/TreeNodeSuite.scala   | 4 ++--
 .../scala/org/apache/spark/sql/execution/SparkPlan.scala  | 9 +++++++--
 .../org/apache/spark/sql/execution/SparkStrategies.scala  | 2 +-
 .../sql/execution/LogicalPlanTagInSparkPlanSuite.scala    | 8 ++++----
 5 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index f12a665541a9..a6968c117782 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode, TreeNodeTagName}
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
 
@@ -271,10 +271,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
 }
 
 object QueryPlan extends PredicateHelper {
-  // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag
-  // when converting a logical plan to a physical plan.
-  val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
-
   /**
    * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
    * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 195314b50de5..5cfa84d2305f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -643,8 +643,8 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
       assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child")
 
       val transformed2 = differentTypeTransform(node)
-      // The parent keeps the tag, but the child loses the tag because it's transformed to a
-      // different type of node.
+      // Both the child and parent keep the tags, even if we transform the node to a new one of
+      // different type.
       assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
       assert(transformed2.children.head.tags.contains(TreeNodeTagName("test")))
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index a89ccca99d05..307a01a50e56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
 
 import org.codehaus.commons.compiler.CompileException
 import org.codehaus.janino.InternalCompilerException
@@ -35,9 +34,15 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.trees.TreeNodeTagName
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.ThreadUtils
+
+object SparkPlan {
+  // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag
+  // when converting a logical plan to a physical plan.
+  val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
+}
 
 /**
  * The base class for physical operators.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 29c04a2f81bb..c9db78b3ed27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -69,7 +69,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         case ReturnAnswer(rootPlan) => rootPlan
         case _ => plan
       }
-      p.tags += QueryPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
+      p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
       p
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
index 10ca5ce5e4ed..ca7ced5ef538 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
@@ -81,12 +81,12 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
       // The exchange related nodes are created after the planning, they don't have corresponding
       // logical plan.
       case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
-        assert(!plan.tags.contains(QueryPlan.LOGICAL_PLAN_TAG_NAME))
+        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
 
       // The subquery exec nodes are just wrappers of the actual nodes, they don't have
       // corresponding logical plan.
       case _: SubqueryExec | _: ReusedSubqueryExec =>
-        assert(!plan.tags.contains(QueryPlan.LOGICAL_PLAN_TAG_NAME))
+        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
 
       case _ if isScanPlanTree(plan) =>
         // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
         // so it's not simple to check. Instead, we only check that the origin LogicalPlan
@@ -120,9 +120,9 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
   }
 
   private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
-    assert(node.tags.contains(QueryPlan.LOGICAL_PLAN_TAG_NAME),
+    assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME),
       node.getClass.getSimpleName + " does not have a logical plan link")
-    node.tags(QueryPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
+    node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
   }
 
   private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {
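Note: after patch 5 the tag constant lives on the `SparkPlan` companion object, so catalyst's `QueryPlan` stays free of execution-side concerns. An end-to-end sketch of the series' final state, assuming a live `SparkSession` named `spark`:

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.SparkPlan

    val physical: SparkPlan = spark.range(100).groupBy("id").count()
      .queryExecution.sparkPlan

    // Nodes tagged by the planner link back to their logical plan. Nodes created
    // after planning (exchanges, reused subqueries) intentionally carry no link.
    physical.foreach { node =>
      node.tags.get(SparkPlan.LOGICAL_PLAN_TAG_NAME)
        .map(_.asInstanceOf[LogicalPlan])
        .foreach(logical => println(s"${node.nodeName} <- ${logical.nodeName}"))
    }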