Skip to content

Commit b5ee466

Browse files
committed
Make test more explicit
1 parent cb4066a commit b5ee466

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
2020
import org.apache.spark.sql.catalyst.QueryPlanningTracker
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.dsl.plans._
23-
import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not}
24-
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
23+
import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not}
2524
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest}
26-
import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan}
25+
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
2726
import org.apache.spark.sql.catalyst.rules.RuleExecutor
27+
import org.apache.spark.sql.types.LongType
2828

2929

3030
class RewriteSubquerySuite extends PlanTest {
@@ -84,10 +84,16 @@ class RewriteSubquerySuite extends PlanTest {
8484
test("SPARK-50091: Don't put aggregate expression in join condition") {
8585
val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int)
8686
val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int)
87-
val query = relation2.select(sum($"col2").in(ListQuery(relation1.select($"c3"))))
88-
89-
val optimized = Optimize.execute(query.analyze)
90-
val join = optimized.find(_.isInstanceOf[Join]).get.asInstanceOf[Join]
91-
assert(!join.condition.get.exists(_.isInstanceOf[AggregateExpression]))
87+
val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3"))))
88+
val optimized = Optimize.execute(plan.analyze)
89+
val aggregate = relation2
90+
.select($"col2")
91+
.groupBy()(sum($"col2").as("_aggregateexpression"))
92+
val correctAnswer = aggregate
93+
.join(relation1.select(Cast($"c3", LongType).as("c3")),
94+
ExistenceJoin($"exists".boolean.withNullability(false)),
95+
Some($"_aggregateexpression" === $"c3"))
96+
.select($"exists".as("(sum(col2) IN (listquery()))")).analyze
97+
comparePlans(optimized, correctAnswer)
9298
}
9399
}

0 commit comments

Comments
 (0)