1616 */
1717package org .apache .spark .sql .catalyst .expressions
1818
19- import scala .collection .mutable
19+ import java .util .IdentityHashMap
20+
21+ import scala .collection .JavaConverters ._
2022
2123import com .google .common .cache .{CacheBuilder , CacheLoader , LoadingCache }
2224import com .google .common .util .concurrent .{ExecutionError , UncheckedExecutionException }
@@ -32,6 +34,10 @@ import org.apache.spark.sql.types.DataType
3234 * intercepts expression evaluation and loads from the cache first.
3335 */
3436class SubExprEvaluationRuntime (cacheMaxEntries : Int ) {
37+ // The next id to assign to an `ExpressionProxy`. `SubExprEvaluationRuntime` uses the assigned
38+ // ids of `ExpressionProxy` to decide equality when loading from the cache. A
39+ // `SubExprEvaluationRuntime` is not used by multiple threads, so no synchronization is needed.
40+ private var proxyExpressionCurrentId = 0
3541
3642 private [sql] val cache : LoadingCache [ExpressionProxy , ResultProxy ] = CacheBuilder .newBuilder()
3743 .maximumSize(cacheMaxEntries)
@@ -68,8 +74,12 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
6874 */
6975 private def replaceWithProxy (
7076 expr : Expression ,
71- proxyMap : Map [Expression , ExpressionProxy ]): Expression = {
72- proxyMap.getOrElse(expr, expr.mapChildren(replaceWithProxy(_, proxyMap)))
77+ proxyMap : IdentityHashMap [Expression , ExpressionProxy ]): Expression = {
78+ if (proxyMap.containsKey(expr)) {
79+ proxyMap.get(expr)
80+ } else {
81+ expr.mapChildren(replaceWithProxy(_, proxyMap))
82+ }
7383 }
7484
7585 /**
@@ -80,19 +90,20 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
8090
8191 expressions.foreach(equivalentExpressions.addExprTree(_))
8292
83- val proxyMap = mutable. Map .empty [Expression , ExpressionProxy ]
93+ val proxyMap = new IdentityHashMap [Expression , ExpressionProxy ]
8494
8595 val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1 )
8696 commonExprs.foreach { e =>
8797 val expr = e.head
88- val proxy = ExpressionProxy (expr, this )
98+ val proxy = ExpressionProxy (expr, proxyExpressionCurrentId, this )
99+ proxyExpressionCurrentId += 1
89100
90- proxyMap ++= e.map(_ -> proxy).toMap
101+ proxyMap.putAll( e.map(_ -> proxy).toMap.asJava)
91102 }
92103
93104 // Only substitute proxies if common subexpressions were found.
94- if (proxyMap.nonEmpty ) {
95- expressions.map(replaceWithProxy(_, proxyMap.toMap ))
105+ if (! proxyMap.isEmpty ) {
106+ expressions.map(replaceWithProxy(_, proxyMap))
96107 } else {
97108 expressions
98109 }
@@ -105,6 +116,7 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
105116 */
106117case class ExpressionProxy (
107118 child : Expression ,
119+ id : Int ,
108120 runtime : SubExprEvaluationRuntime ) extends Expression {
109121
110122 final override def dataType : DataType = child.dataType
@@ -118,6 +130,13 @@ case class ExpressionProxy(
118130 def proxyEval (input : InternalRow = null ): Any = child.eval(input)
119131
120132 override def eval (input : InternalRow = null ): Any = runtime.getEval(this )
133+
134+ override def equals (obj : Any ): Boolean = obj match {
135+ case other : ExpressionProxy => this .id == other.id
136+ case _ => false
137+ }
138+
139+ override def hashCode (): Int = this .id.hashCode()
121140}
122141
123142/**
0 commit comments