@@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
2020import org .apache .spark .sql .catalyst .QueryPlanningTracker
2121import org .apache .spark .sql .catalyst .dsl .expressions ._
2222import org .apache .spark .sql .catalyst .dsl .plans ._
23- import org .apache .spark .sql .catalyst .expressions .{IsNull , ListQuery , Not }
24- import org .apache .spark .sql .catalyst .expressions .aggregate .AggregateExpression
23+ import org .apache .spark .sql .catalyst .expressions .{Cast , IsNull , ListQuery , Not }
2524import org .apache .spark .sql .catalyst .plans .{ExistenceJoin , LeftSemi , PlanTest }
26- import org .apache .spark .sql .catalyst .plans .logical .{Join , LocalRelation , LogicalPlan }
25+ import org .apache .spark .sql .catalyst .plans .logical .{LocalRelation , LogicalPlan }
2726import org .apache .spark .sql .catalyst .rules .RuleExecutor
27+ import org .apache .spark .sql .types .LongType
2828
2929
3030class RewriteSubquerySuite extends PlanTest {
@@ -84,10 +84,16 @@ class RewriteSubquerySuite extends PlanTest {
8484 test(" SPARK-50091: Don't put aggregate expression in join condition" ) {
8585 val relation1 = LocalRelation ($" c1" .int, $" c2" .int, $" c3" .int)
8686 val relation2 = LocalRelation ($" col1" .int, $" col2" .int, $" col3" .int)
87- val query = relation2.select(sum($" col2" ).in(ListQuery (relation1.select($" c3" ))))
88-
89- val optimized = Optimize .execute(query.analyze)
90- val join = optimized.find(_.isInstanceOf [Join ]).get.asInstanceOf [Join ]
91- assert(! join.condition.get.exists(_.isInstanceOf [AggregateExpression ]))
87+ val plan = relation2.groupBy()(sum($" col2" ).in(ListQuery (relation1.select($" c3" ))))
88+ val optimized = Optimize .execute(plan.analyze)
89+ val aggregate = relation2
90+ .select($" col2" )
91+ .groupBy()(sum($" col2" ).as(" _aggregateexpression" ))
92+ val correctAnswer = aggregate
93+ .join(relation1.select(Cast ($" c3" , LongType ).as(" c3" )),
94+ ExistenceJoin ($" exists" .boolean.withNullability(false )),
95+ Some ($" _aggregateexpression" === $" c3" ))
96+ .select($" exists" .as(" (sum(col2) IN (listquery()))" )).analyze
97+ comparePlans(optimized, correctAnswer)
9298 }
9399}
0 commit comments