Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
val input = children.flatMap(_.output)
productIterator.map {
// Children are checked using sameResult above.
case tn: TreeNode[_] if children contains tn => null
case tn: TreeNode[_] if containsChild(tn) => null
case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
case s: Option[_] => s.map {
case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {

val origin: Origin = CurrentOrigin.get

/** Returns a Seq of the children of this node */
/**
* Returns a Seq of the children of this node.
* Children should not change. Immutability required for containsChild optimization
*/
def children: Seq[BaseType]

lazy val containsChild: Set[TreeNode[_]] = children.toSet

/**
* Faster version of equality which short-circuits when two treeNodes are the same instance.
* We don't just override Object.equals, as doing so prevents the scala compiler from
Expand Down Expand Up @@ -147,7 +152,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
def mapChildren(f: BaseType => BaseType): this.type = {
var changed = false
val newArgs = productIterator.map {
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (newChild fastEquals arg) {
arg
Expand All @@ -173,7 +178,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
val newArgs = productIterator.map {
// Handle Seq[TreeNode] in TreeNode parameters.
case s: Seq[_] => s.map {
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
if (newChild fastEquals oldChild) {
Expand All @@ -185,7 +190,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
case nonChild: AnyRef => nonChild
case null => null
}
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
if (newChild fastEquals oldChild) {
Expand Down Expand Up @@ -238,15 +243,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = {
var changed = false
val newArgs = productIterator.map {
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
if (!(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case Some(arg: TreeNode[_]) if children contains arg =>
case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
if (!(newChild fastEquals arg)) {
changed = true
Expand All @@ -257,7 +262,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
if (!(newChild fastEquals arg)) {
changed = true
Expand Down Expand Up @@ -295,15 +300,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = {
var changed = false
val newArgs = productIterator.map {
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
if (!(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case Some(arg: TreeNode[_]) if children contains arg =>
case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
if (!(newChild fastEquals arg)) {
changed = true
Expand All @@ -314,7 +319,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
if (!(newChild fastEquals arg)) {
changed = true
Expand Down Expand Up @@ -383,7 +388,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {

/** Returns a string representing the arguments to this node, minus any children */
def argString: String = productIterator.flatMap {
case tn: TreeNode[_] if children contains tn => Nil
case tn: TreeNode[_] if containsChild(tn) => Nil
case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil
case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil
case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil
Expand Down