Skip to content

Commit db115d6

Browse files
committed
For review comment.
1 parent 77168fe commit db115d6

File tree

2 files changed

+36
-16
lines changed

2 files changed

+36
-16
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
*/
1717
package org.apache.spark.sql.catalyst.expressions
1818

19-
import scala.collection.mutable
19+
import java.util.IdentityHashMap
20+
21+
import scala.collection.JavaConverters._
2022

2123
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
2224
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
@@ -32,6 +34,10 @@ import org.apache.spark.sql.types.DataType
3234
* intercepts expression evaluation and loads from the cache first.
3335
*/
3436
class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
37+
// The id assigned to `ExpressionProxy`. `SubExprEvaluationRuntime` will use assigned ids of
38+
// `ExpressionProxy` to decide the equality when loading from cache. `SubExprEvaluationRuntime`
39+
// won't be use by multi-threads so we don't need to consider concurrency here.
40+
private var proxyExpressionCurrentId = 0
3541

3642
private[sql] val cache: LoadingCache[ExpressionProxy, ResultProxy] = CacheBuilder.newBuilder()
3743
.maximumSize(cacheMaxEntries)
@@ -68,8 +74,12 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
6874
*/
6975
private def replaceWithProxy(
7076
expr: Expression,
71-
proxyMap: Map[Expression, ExpressionProxy]): Expression = {
72-
proxyMap.getOrElse(expr, expr.mapChildren(replaceWithProxy(_, proxyMap)))
77+
proxyMap: IdentityHashMap[Expression, ExpressionProxy]): Expression = {
78+
if (proxyMap.containsKey(expr)) {
79+
proxyMap.get(expr)
80+
} else {
81+
expr.mapChildren(replaceWithProxy(_, proxyMap))
82+
}
7383
}
7484

7585
/**
@@ -80,19 +90,20 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
8090

8191
expressions.foreach(equivalentExpressions.addExprTree(_))
8292

83-
val proxyMap = mutable.Map.empty[Expression, ExpressionProxy]
93+
val proxyMap = new IdentityHashMap[Expression, ExpressionProxy]
8494

8595
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
8696
commonExprs.foreach { e =>
8797
val expr = e.head
88-
val proxy = ExpressionProxy(expr, this)
98+
val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this)
99+
proxyExpressionCurrentId += 1
89100

90-
proxyMap ++= e.map(_ -> proxy).toMap
101+
proxyMap.putAll(e.map(_ -> proxy).toMap.asJava)
91102
}
92103

93104
// Only adding proxy if we find subexpressions.
94-
if (proxyMap.nonEmpty) {
95-
expressions.map(replaceWithProxy(_, proxyMap.toMap))
105+
if (!proxyMap.isEmpty) {
106+
expressions.map(replaceWithProxy(_, proxyMap))
96107
} else {
97108
expressions
98109
}
@@ -105,6 +116,7 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
105116
*/
106117
case class ExpressionProxy(
107118
child: Expression,
119+
id: Int,
108120
runtime: SubExprEvaluationRuntime) extends Expression {
109121

110122
final override def dataType: DataType = child.dataType
@@ -118,6 +130,13 @@ case class ExpressionProxy(
118130
def proxyEval(input: InternalRow = null): Any = child.eval(input)
119131

120132
override def eval(input: InternalRow = null): Any = runtime.getEval(this)
133+
134+
override def equals(obj: Any): Boolean = obj match {
135+
case other: ExpressionProxy => this.id == other.id
136+
case _ => false
137+
}
138+
139+
override def hashCode(): Int = this.id.hashCode()
121140
}
122141

123142
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
package org.apache.spark.sql.catalyst.expressions
1818

1919
import org.apache.spark.SparkFunSuite
20+
import org.apache.spark.sql.types.IntegerType
2021

2122
class SubExprEvaluationRuntimeSuite extends SparkFunSuite {
2223

2324
test("Evaluate ExpressionProxy should create cached result") {
2425
val runtime = new SubExprEvaluationRuntime(1)
25-
val proxy = ExpressionProxy(Literal(1), runtime)
26+
val proxy = ExpressionProxy(Literal(1), 0, runtime)
2627
assert(runtime.cache.size() == 0)
2728
proxy.eval()
2829
assert(runtime.cache.size() == 1)
@@ -33,31 +34,31 @@ class SubExprEvaluationRuntimeSuite extends SparkFunSuite {
3334
val runtime = new SubExprEvaluationRuntime(2)
3435
assert(runtime.cache.size() == 0)
3536

36-
val proxy1 = ExpressionProxy(Literal(1), runtime)
37+
val proxy1 = ExpressionProxy(Literal(1), 0, runtime)
3738
proxy1.eval()
3839
assert(runtime.cache.size() == 1)
3940
assert(runtime.cache.get(proxy1) == ResultProxy(1))
4041

41-
val proxy2 = ExpressionProxy(Literal(2), runtime)
42+
val proxy2 = ExpressionProxy(Literal(2), 1, runtime)
4243
proxy2.eval()
4344
assert(runtime.cache.size() == 2)
4445
assert(runtime.cache.get(proxy2) == ResultProxy(2))
4546

46-
val proxy3 = ExpressionProxy(Literal(3), runtime)
47+
val proxy3 = ExpressionProxy(Literal(3), 2, runtime)
4748
proxy3.eval()
4849
assert(runtime.cache.size() == 2)
4950
assert(runtime.cache.get(proxy3) == ResultProxy(3))
5051
}
5152

5253
test("setInput should empty cached result") {
5354
val runtime = new SubExprEvaluationRuntime(2)
54-
val proxy1 = ExpressionProxy(Literal(1), runtime)
55+
val proxy1 = ExpressionProxy(Literal(1), 0, runtime)
5556
assert(runtime.cache.size() == 0)
5657
proxy1.eval()
5758
assert(runtime.cache.size() == 1)
5859
assert(runtime.cache.get(proxy1) == ResultProxy(1))
5960

60-
val proxy2 = ExpressionProxy(Literal(2), runtime)
61+
val proxy2 = ExpressionProxy(Literal(2), 1, runtime)
6162
proxy2.eval()
6263
assert(runtime.cache.size() == 2)
6364
assert(runtime.cache.get(proxy2) == ResultProxy(2))
@@ -83,8 +84,8 @@ class SubExprEvaluationRuntimeSuite extends SparkFunSuite {
8384
})
8485
// ( (one * two) * (one * two) )
8586
assert(proxys.size == 2)
86-
val expected = ExpressionProxy(mul2, runtime)
87-
assert(proxys.head == expected)
87+
val expected = ExpressionProxy(mul2, 0, runtime)
88+
assert(proxys.forall(_ == expected))
8889
}
8990

9091
test("ExpressionProxy won't be on non deterministic") {

0 commit comments

Comments
 (0)