@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
2828import org .apache .spark .sql .catalyst .expressions .SubqueryExpression
2929import org .apache .spark .sql .catalyst .plans .logical .{BROADCAST , Join , JoinStrategyHint , SHUFFLE_HASH }
3030import org .apache .spark .sql .catalyst .util .DateTimeConstants
31- import org .apache .spark .sql .execution .{RDDScanExec , ScalarSubquery , SparkPlan }
31+ import org .apache .spark .sql .execution .{ExecSubqueryExpression , RDDScanExec , SparkPlan }
3232import org .apache .spark .sql .execution .columnar ._
3333import org .apache .spark .sql .execution .exchange .ShuffleExchangeExec
3434import org .apache .spark .sql .functions ._
@@ -89,24 +89,19 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSparkSessi
8989 sum
9090 }
9191
92- private def getExpressionSubqueryInMemoryTables (plan : SparkPlan ): Int = {
93- var inMemoryTableNum = 0
94- plan.transformExpressions {
95- case sub : ScalarSubquery =>
96- inMemoryTableNum = inMemoryTableNum + getNumInMemoryTablesRecursively(sub.plan)
97- sub
98- case e => e
99- }
100- inMemoryTableNum
92+ private def getNumInMemoryTablesInSubquery (plan : SparkPlan ): Int = {
93+ plan.expressions.map(_.collect {
94+ case sub : ExecSubqueryExpression => getNumInMemoryTablesRecursively(sub.plan)
95+ }.sum).sum
10196 }
10297
10398 private def getNumInMemoryTablesRecursively (plan : SparkPlan ): Int = {
10499 plan.collect {
105100 case inMemoryTable @ InMemoryTableScanExec (_, _, relation) =>
106101 getNumInMemoryTablesRecursively(relation.cachedPlan) +
107- getExpressionSubqueryInMemoryTables (inMemoryTable) + 1
102+ getNumInMemoryTablesInSubquery (inMemoryTable) + 1
108103 case p =>
109- getExpressionSubqueryInMemoryTables (p)
104+ getNumInMemoryTablesInSubquery (p)
110105 }.sum
111106 }
112107
0 commit comments