diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 6683f2dbfb392..4b8556b1bb5de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.trees -import java.util.UUID +import java.util.{IdentityHashMap, UUID} import scala.annotation.nowarn import scala.collection.{mutable, Map} @@ -841,7 +841,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] */ protected def stringArgs: Iterator[Any] = productIterator - private lazy val allChildren: Set[TreeNode[_]] = (children ++ innerChildren).toSet[TreeNode[_]] + private lazy val allChildren: IdentityHashMap[TreeNode[_], Any] = { + val set = new IdentityHashMap[TreeNode[_], Any]() + (children ++ innerChildren).foreach { + set.put(_, null) + } + set + } private def redactMapString[K, V](map: Map[K, V], maxFields: Int): List[String] = { // For security reason, redact the map value if the key is in certain patterns @@ -868,11 +874,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] /** Returns a string representing the arguments to this node, minus any children */ def argString(maxFields: Int): String = stringArgs.flatMap { - case tn: TreeNode[_] if allChildren.contains(tn) => Nil - case Some(tn: TreeNode[_]) if allChildren.contains(tn) => Nil + case tn: TreeNode[_] if allChildren.containsKey(tn) => Nil + case Some(tn: TreeNode[_]) if allChildren.containsKey(tn) => Nil case Some(tn: TreeNode[_]) => tn.simpleString(maxFields) :: Nil case tn: TreeNode[_] => tn.simpleString(maxFields) :: Nil - case seq: Seq[Any] if seq.toSet.subsetOf(allChildren.asInstanceOf[Set[Any]]) => Nil + case seq: Seq[Any] if seq.forall(allChildren.containsKey) => Nil case iter: Iterable[_] if iter.isEmpty => Nil case array: Array[_] if array.isEmpty => Nil case xs @ (_: Seq[_] | _: Set[_] | _: Array[_]) =>