File: basicLogicalOperators.scala
@@ -1078,7 +1078,11 @@ case class OneRowRelation() extends LeafNode {
   override def computeStats(): Statistics = Statistics(sizeInBytes = 1)
 
   /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
-  override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = OneRowRelation()
+  override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
+    val newCopy = OneRowRelation()
+    newCopy.tags ++= this.tags
+    newCopy
+  }
 }
 
 /** A logical plan for `dropDuplicates`. */
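`OneRowRelation` needs this hand-written copy because it bypasses the reflection-based `TreeNode.makeCopy` (which, per the `TreeNode.scala` hunk below, now carries tags automatically). Any other node that hand-rolls `makeCopy` needs the same treatment. A minimal sketch of the pattern; `MyZeroArgRelation` is illustrative, not part of the patch:

    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}

    // Hypothetical 0-arg operator: a hand-written makeCopy must copy the tags
    // itself, because it skips TreeNode's reflection path that normally does it.
    case class MyZeroArgRelation() extends LeafNode {
      override def output: Seq[Attribute] = Nil
      override def computeStats(): Statistics = Statistics(sizeInBytes = 1)
      override def makeCopy(newArgs: Array[AnyRef]): MyZeroArgRelation = {
        val copied = MyZeroArgRelation()
        copied.tags ++= this.tags // same pattern as OneRowRelation above
        copied
      }
    }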
File: TreeNode.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.trees
 
 import java.util.UUID
 
-import scala.collection.Map
+import scala.collection.{mutable, Map}
 import scala.reflect.ClassTag
 
 import org.apache.commons.lang3.ClassUtils
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
-import org.apache.spark.sql.catalyst.util.StringUtils.{PlanStringConcat, StringConcat}
+import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -74,13 +74,23 @@ object CurrentOrigin {
   }
 }
 
+// The name of a tree node tag. This is preferred over using a string directly, as it makes it
+// easy to find all the defined tags.
+case class TreeNodeTagName(name: String)
+
 // scalastyle:off
 abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 // scalastyle:on
   self: BaseType =>
 
   val origin: Origin = CurrentOrigin.get
 
+  /**
+   * A mutable map for holding auxiliary information about this tree node. It is carried over
+   * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
+   */
+  val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
+
   /**
    * Returns a Seq of the children of this node.
    * Children should not change. Immutability required for containsChild optimization
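A minimal, self-contained sketch (not Spark code) of the mechanism this hunk adds: every node owns a mutable tag map, and any operation that replaces a node with a new instance copies the map over, so auxiliary metadata survives tree rewrites. All names here are illustrative:

    import scala.collection.mutable

    object TagCarryOverSketch extends App {
      case class TagName(name: String)

      class Node {
        val tags: mutable.Map[TagName, Any] = mutable.Map.empty

        // Mirrors what transformDown/transformUp/makeCopy now do on replacement.
        def replacedBy(replacement: Node): Node = {
          replacement.tags ++= this.tags
          replacement
        }
      }

      val original = new Node
      original.tags += TagName("test") -> "a"
      val rewritten = original.replacedBy(new Node)
      assert(rewritten.tags(TagName("test")) == "a") // the tag survived the rewrite
      println("tag carried over: " + rewritten.tags(TagName("test")))
    }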
@@ -262,6 +272,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     if (this fastEquals afterRule) {
       mapChildren(_.transformDown(rule))
     } else {
+      // If the transform function replaces this node with a new one, carry over the tags.
+      afterRule.tags ++= this.tags
       afterRule.mapChildren(_.transformDown(rule))
     }
   }
@@ -275,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    */
   def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
     val afterRuleOnChildren = mapChildren(_.transformUp(rule))
-    if (this fastEquals afterRuleOnChildren) {
+    val newNode = if (this fastEquals afterRuleOnChildren) {
       CurrentOrigin.withOrigin(origin) {
         rule.applyOrElse(this, identity[BaseType])
       }

[Review comment — Member]
Here, after applying the rule, we also need to carry over the tags, if needed. Since the tags are
carried over in both the if and else branches, maybe:

    val newNode = if (this fastEquals afterRuleOnChildren) {
      CurrentOrigin.withOrigin(origin) {
        rule.applyOrElse(this, identity[BaseType])
      }
    } else {
      CurrentOrigin.withOrigin(origin) {
        rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
      }
    }

    // carrying over the tags to newNode...

@@ -284,6 +296,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
         rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
       }
     }
+    // If the transform function replaces this node with a new one, carry over the tags.
+    newNode.tags ++= this.tags
+    newNode
   }
 
   /**
@@ -402,7 +417,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
     try {
       CurrentOrigin.withOrigin(origin) {
-        defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
+        val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
+        res.tags ++= this.tags
+        res
       }
     } catch {
       case e: java.lang.IllegalArgumentException =>
File: TreeNodeSuite.scala
@@ -620,4 +620,55 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
       assert(planString.startsWith("Truncated plan of"))
     }
   }
+
+  test("tags will be carried over after copy & transform") {
+    withClue("makeCopy") {
+      val node = Dummy(None)
+      node.tags += TreeNodeTagName("test") -> "a"
+      val copied = node.makeCopy(Array(Some(Literal(1))))
+      assert(copied.tags(TreeNodeTagName("test")) == "a")
+    }
+
+    def checkTransform(
+        sameTypeTransform: Expression => Expression,
+        differentTypeTransform: Expression => Expression): Unit = {
+      val child = Dummy(None)
+      child.tags += TreeNodeTagName("test") -> "child"
+      val node = Dummy(Some(child))
+      node.tags += TreeNodeTagName("test") -> "parent"
+
+      val transformed = sameTypeTransform(node)
+      // Both the child and parent keep their tags.
+      assert(transformed.tags(TreeNodeTagName("test")) == "parent")
+      assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child")
+
+      val transformed2 = differentTypeTransform(node)
+      // Both the child and parent keep their tags, even if we transform the node to a new one
+      // of a different type.
+      assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
+      assert(transformed2.children.head.tags.contains(TreeNodeTagName("test")))
+    }
+
+    withClue("transformDown") {
+      checkTransform(
+        sameTypeTransform = _ transformDown {
+          case Dummy(None) => Dummy(Some(Literal(1)))
+        },
+        differentTypeTransform = _ transformDown {
+          case Dummy(None) => Literal(1)
+        })
+    }
+
+    withClue("transformUp") {
+      checkTransform(
+        sameTypeTransform = _ transformUp {
+          case Dummy(None) => Dummy(Some(Literal(1)))
+        },
+        differentTypeTransform = _ transformUp {
+          case Dummy(None) => Literal(1)
+        })
+    }
+  }
 }
File: SparkPlan.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
 
 import org.codehaus.commons.compiler.CompileException
 import org.codehaus.janino.InternalCompilerException
@@ -35,9 +34,15 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.trees.TreeNodeTagName
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.ThreadUtils
 
+object SparkPlan {
+  // A TreeNode tag on SparkPlan that carries its original logical plan. The planner adds this
+  // tag when converting a logical plan to a physical plan.
+  val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
+}
+
 /**
  * The base class for physical operators.
File: SparkStrategies.scala
@@ -63,6 +63,17 @@ case class PlanLater(plan: LogicalPlan) extends LeafExecNode {
 abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SparkPlanner =>
 
+  override def plan(plan: LogicalPlan): Iterator[SparkPlan] = {
+    super.plan(plan).map { p =>
+      val logicalPlan = plan match {
+        case ReturnAnswer(rootPlan) => rootPlan
+        case _ => plan
+      }
+      p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
+      p
+    }
+  }
+
   /**
    * Plans special cases of limit operators.
    */
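With this single hook, every physical plan the planner emits links back to the logical plan it was created from (`ReturnAnswer` is unwrapped first, since it is a planning-only wrapper). A hedged usage sketch; the `spark` session and query are illustrative, and the lookup mirrors `getLogicalPlan` in the new test suite below:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.SparkPlan

    object LogicalLinkSketch extends App {
      val spark = SparkSession.builder().master("local[2]").appName("tag-sketch").getOrCreate()
      val df = spark.range(10).selectExpr("id * 2 AS x")

      // The overridden SparkStrategies.plan attached LOGICAL_PLAN_TAG_NAME during planning.
      val physical = df.queryExecution.sparkPlan
      val logical: Option[LogicalPlan] =
        physical.tags.get(SparkPlan.LOGICAL_PLAN_TAG_NAME).map(_.asInstanceOf[LogicalPlan])
      logical.foreach(lp => println(lp)) // the logical plan this physical plan came from

      spark.stop()
    }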
File: LogicalPlanTagInSparkPlanSuite.scala (new file)
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.TPCDSQuerySuite
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.execution.window.WindowExec
+
+class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
+
+  override protected def checkGeneratedCode(plan: SparkPlan): Unit = {
+    super.checkGeneratedCode(plan)
+    checkLogicalPlanTag(plan)
+  }
+
+  private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = {
+    // TODO: an aggregate node without aggregate expressions can also be a final aggregate, but
+    // currently the aggregate node doesn't carry a final/partial flag.
+    aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final)
+  }
+
+  // A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes.
+  private def isScanPlanTree(plan: SparkPlan): Boolean = plan match {
+    case p: ProjectExec => isScanPlanTree(p.child)
+    case f: FilterExec => isScanPlanTree(f.child)
+    case _: LeafExecNode => true
+    case _ => false
+  }
+
+  private def checkLogicalPlanTag(plan: SparkPlan): Unit = {
+    plan match {
+      case _: HashJoin | _: BroadcastNestedLoopJoinExec | _: CartesianProductExec
+          | _: ShuffledHashJoinExec | _: SortMergeJoinExec =>
+        assertLogicalPlanType[Join](plan)
+
+      // There is no corresponding logical plan for the physical partial aggregate.
+      case agg: HashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
+        assertLogicalPlanType[Aggregate](plan)
+      case agg: ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
+        assertLogicalPlanType[Aggregate](plan)
+      case agg: SortAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
+        assertLogicalPlanType[Aggregate](plan)
+
+      case _: WindowExec =>
+        assertLogicalPlanType[Window](plan)
+
+      case _: UnionExec =>
+        assertLogicalPlanType[Union](plan)
+
+      case _: SampleExec =>
+        assertLogicalPlanType[Sample](plan)
+
+      case _: GenerateExec =>
+        assertLogicalPlanType[Generate](plan)
+
+      // Exchange-related nodes are created after planning; they don't have a corresponding
+      // logical plan.
+      case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>

[Review comment — Member]
Add ReusedSubqueryExec?

[Reply — Contributor, author]
Added below.

+        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+
+      // The subquery exec nodes are just wrappers of the actual nodes; they don't have a
+      // corresponding logical plan.
+      case _: SubqueryExec | _: ReusedSubqueryExec =>
+        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+
+      case _ if isScanPlanTree(plan) =>
+        // The strategies for planning a scan can remove or add FilterExec/ProjectExec nodes.
+        // For example, a strategy might remove the filter if it's fully pushed down:
+        //   logical  = Project(Filter(Scan A))
+        //   physical = ProjectExec(ScanExec A)
+        // An exact match is therefore not simple to check. Instead, we only check that the
+        // leaf nodes match between the logical and physical plans.
+        val logicalLeaves = getLogicalPlan(plan).collectLeaves()
+        val physicalLeaves = plan.collectLeaves()
+        assert(logicalLeaves.length == 1)
+        assert(physicalLeaves.length == 1)
+        physicalLeaves.head match {
+          case _: RangeExec => assert(logicalLeaves.head.isInstanceOf[Range])
+          case _: DataSourceScanExec => assert(logicalLeaves.head.isInstanceOf[LogicalRelation])
+          case _: InMemoryTableScanExec => assert(logicalLeaves.head.isInstanceOf[InMemoryRelation])
+          case _: LocalTableScanExec => assert(logicalLeaves.head.isInstanceOf[LocalRelation])
+          case _: ExternalRDDScanExec[_] => assert(logicalLeaves.head.isInstanceOf[ExternalRDD[_]])
+          case _: BatchScanExec => assert(logicalLeaves.head.isInstanceOf[DataSourceV2Relation])
+          case _ =>
+        }
+        // No need to check the children recursively.
+        return
+
+      case _ =>
+    }
+
+    plan.children.foreach(checkLogicalPlanTag)
+    plan.subqueries.foreach(checkLogicalPlanTag)
+  }
+
+  private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
+    assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME),
+      node.getClass.getSimpleName + " does not have a logical plan link")
+    node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
+  }
+
+  private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {
+    val logicalPlan = getLogicalPlan(node)
+    val expectedCls = implicitly[ClassTag[T]].runtimeClass
+    assert(expectedCls == logicalPlan.getClass)
+  }
+}
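The suite piggybacks on TPCDSQuerySuite, so every TPCDS query exercises the tag check via the `checkGeneratedCode` hook. For local verification, Spark suites are typically run through sbt; assuming the suite lives in the `sql` module, something like:

    build/sbt "sql/testOnly *LogicalPlanTagInSparkPlanSuite"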