Skip to content

Commit fb562fb

Browse files
committed
Fix pattern with casts and add more test cases.
1 parent 00e957c commit fb562fb

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,20 @@ object ExtractNonNullableAttributes extends Logging with PredicateHelper {
203203
result.add(b)
204204
}
205205
}
206+
case BinaryComparison(Cast(a: Attribute, _), Cast(b: Attribute, _)) => {
207+
if (!e.isInstanceOf[EqualNullSafe]) {
208+
result.add(a)
209+
result.add(b)
210+
}
211+
}
206212
case BinaryComparison(a: Attribute, _) => if (!e.isInstanceOf[EqualNullSafe]) result.add(a)
207213
case BinaryComparison(_, a: Attribute) => if (!e.isInstanceOf[EqualNullSafe]) result.add(a)
214+
case BinaryComparison(Cast(a: Attribute, _), _) =>
215+
if (!e.isInstanceOf[EqualNullSafe]) result.add(a)
216+
case BinaryComparison(_, Cast(a: Attribute, _)) =>
217+
if (!e.isInstanceOf[EqualNullSafe]) result.add(a)
208218
case Not(child) => extract(child)
219+
case _ =>
209220
}
210221
predicates.foreach { extract(_) }
211222
result.toSet

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFilterSuite.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
2424
import org.apache.spark.sql.catalyst.rules._
2525
import org.apache.spark.sql.catalyst.dsl.plans._
2626
import org.apache.spark.sql.catalyst.dsl.expressions._
27+
import org.apache.spark.sql.types.DoubleType
2728

2829
class JoinFilterSuite extends PlanTest {
2930

@@ -64,6 +65,37 @@ class JoinFilterSuite extends PlanTest {
6465
comparePlans(optimized, correctAnswer)
6566
}
6667

68+
test("joins infer is NOT NULL one key") {
69+
val x = testRelation.subquery('x)
70+
val y = testRelation.subquery('y)
71+
72+
val originalQuery = x.join(y).
73+
where("x.b".attr + 1 === "y.b".attr)
74+
75+
val optimized = Optimize.execute(originalQuery.analyze)
76+
77+
val correctAnswer = x.join(
78+
Filter(IsNotNull("y.b".attr), y), Inner, Some("x.b".attr + 1 === "y.b".attr)).analyze
79+
80+
comparePlans(optimized, correctAnswer)
81+
}
82+
83+
test("joins infer is NOT NULL for cast") {
84+
val x = testRelation.subquery('x)
85+
val y = testRelation.subquery('y)
86+
87+
val originalQuery = x.join(y).
88+
where(Cast("x.b".attr, DoubleType) === "y.b".attr)
89+
90+
val optimized = Optimize.execute(originalQuery.analyze)
91+
92+
val correctAnswer =
93+
Filter(IsNotNull("x.b".attr), x).join(
94+
Filter(IsNotNull("y.b".attr), y), Inner,
95+
Some(Cast("x.b".attr, DoubleType) === Cast("y.b".attr, DoubleType))).analyze
96+
comparePlans(optimized, correctAnswer)
97+
}
98+
6799
test("joins infer is NOT NULL on join keys") {
68100
val x = testRelation.subquery('x)
69101
val y = testRelation.subquery('y)

0 commit comments

Comments
 (0)