Skip to content

Commit f15ef96

Browse files
committed
support all joins
1 parent 7fb2f9c commit f15ef96

File tree

2 files changed

+91
-25
lines changed

2 files changed

+91
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,13 @@ case class Join(
215215
}
216216
}
217217

218-
def extractNullabilityConstraintsFromJoinCondition(): Set[Expression] = {
218+
override def constraints: Set[Expression] = {
219219
var constraints = Set.empty[Expression]
220-
if (condition.isDefined) {
221-
splitConjunctivePredicates(condition.get).foreach {
220+
221+
// Currently we only propagate constraints if the condition consists of equality
222+
// and ranges. For all other cases, we return an empty set of constraints
223+
def extractIsNotNullConstraints(condition: Expression): Set[Expression] = {
224+
splitConjunctivePredicates(condition).foreach {
222225
case EqualTo(l, r) =>
223226
constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r)))
224227
case GreaterThan(l, r) =>
@@ -230,29 +233,32 @@ case class Join(
230233
case LessThanOrEqual(l, r) =>
231234
constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r)))
232235
}
236+
constraints
233237
}
234-
// Currently we only propagate constraints if the condition consists of equality
235-
// and ranges. For all other cases, we return an empty set of constraints
236-
constraints
237-
}
238238

239-
override def constraints: Set[Expression] = {
240-
joinType match {
241-
case Inner =>
239+
def extractIsNullConstraints(plan: LogicalPlan): Set[Expression] = {
240+
constraints = constraints.union(plan.output.map(IsNull).toSet)
241+
constraints
242+
}
243+
244+
constraints = joinType match {
245+
case Inner if condition.isDefined =>
242246
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
243-
.union(extractNullabilityConstraintsFromJoinCondition())
244-
case LeftSemi =>
245-
extractConstraintsFromChild(left)
246-
.union(extractNullabilityConstraintsFromJoinCondition())
247-
case LeftOuter =>
247+
.union(extractIsNotNullConstraints(condition.get))
248+
case LeftSemi if condition.isDefined =>
248249
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
250+
.union(extractIsNotNullConstraints(condition.get))
251+
case LeftOuter =>
252+
extractConstraintsFromChild(left).union(extractIsNullConstraints(right))
249253
case RightOuter =>
250-
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
254+
extractConstraintsFromChild(right).union(extractIsNullConstraints(left))
251255
case FullOuter =>
252-
extractConstraintsFromChild(left).union(extractConstraintsFromChild(right))
256+
extractIsNullConstraints(left).union(extractIsNullConstraints(right))
253257
case _ =>
254258
Set.empty
255259
}
260+
261+
constraints.filter(_.references.subsetOf(outputSet))
256262
}
257263

258264
// True when the left and right children have no output attributes in common.
def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,19 @@ class ConstraintPropagationSuite extends SparkFunSuite {
2929
private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
3030
tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get
3131

32+
/**
 * Asserts that `a` and `b` contain the same constraints up to `semanticEquals`: every
 * expression in one set must have a semantically equal counterpart in the other.
 *
 * `exists` is used rather than `map(...).reduce(_ || _)` because `reduce` throws
 * UnsupportedOperationException on an empty set, which would surface as an error instead
 * of a failed assertion whenever exactly one of the two sets is empty.
 */
private def verifyConstraints(a: Set[Expression], b: Set[Expression]): Unit = {
33+
assert(a.forall(i => b.exists(_.semanticEquals(i))))
34+
assert(b.forall(i => a.exists(_.semanticEquals(i))))
35+
}
36+
3237
test("propagating constraints in filter/project") {
3338
val tr = LocalRelation('a.int, 'b.string, 'c.int)
3439
assert(tr.analyze.constraints.isEmpty)
3540
assert(tr.select('a.attr).analyze.constraints.isEmpty)
36-
assert(tr.where('a.attr > 10).analyze.constraints == Set(resolveColumn(tr, "a") > 10))
3741
assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty)
38-
assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100)
39-
.analyze.constraints == Set(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100))
42+
verifyConstraints(tr.where('a.attr > 10).analyze.constraints, Set(resolveColumn(tr, "a") > 10))
43+
verifyConstraints(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100)
44+
.analyze.constraints, Set(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100))
4045
}
4146

4247
test("propagating constraints in union") {
@@ -45,21 +50,76 @@ class ConstraintPropagationSuite extends SparkFunSuite {
4550
val tr3 = LocalRelation('g.int, 'h.int, 'i.int)
4651
assert(tr1.where('a.attr > 10).unionAll(tr2.where('e.attr > 10)
4752
.unionAll(tr3.where('i.attr > 10))).analyze.constraints.isEmpty)
48-
assert(tr1.where('a.attr > 10).unionAll(tr2.where('d.attr > 10)
49-
.unionAll(tr3.where('g.attr > 10))).analyze.constraints == Set(resolveColumn(tr1, "a") > 10))
53+
verifyConstraints(tr1.where('a.attr > 10).unionAll(tr2.where('d.attr > 10)
54+
.unionAll(tr3.where('g.attr > 10))).analyze.constraints, Set(resolveColumn(tr1, "a") > 10))
5055
}
5156

5257
test("propagating constraints in intersect") {
5358
val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
5459
val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
55-
assert(tr1.where('a.attr > 10).intersect(tr2.where('b.attr < 100)).analyze.constraints ==
56-
Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100))
60+
verifyConstraints(tr1.where('a.attr > 10).intersect(tr2.where('b.attr < 100))
61+
.analyze.constraints, Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100))
5762
}
5863

5964
test("propagating constraints in except") {
6065
val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
6166
val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
62-
assert(tr1.where('a.attr > 10).except(tr2.where('b.attr < 100)).analyze.constraints ==
67+
verifyConstraints(tr1.where('a.attr > 10).except(tr2.where('b.attr < 100)).analyze.constraints,
6368
Set(resolveColumn(tr1, "a") > 10))
6469
}
70+
71+
// Inner join on tr1.a = tr2.a: the expected set is the filter predicate of each child
// (tr1.a > 10, tr2.d < 100) plus IsNotNull on both sides of the equality.
test("propagating constraints in inner join") {
72+
val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
73+
val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
74+
verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), Inner,
75+
Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints,
76+
Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
77+
tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
78+
IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get),
79+
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
80+
}
81+
82+
// Left-semi join on tr1.a = tr2.a: only expressions over tr1 are expected, namely the
// left filter predicate and IsNotNull on the left join key.
test("propagating constraints in left-semi join") {
83+
val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
84+
val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
85+
verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), LeftSemi,
86+
Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints,
87+
Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
88+
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
89+
}
90+
91+
// Left-outer join on tr1.a = tr2.a: the expected set is the left filter predicate plus
// IsNull on every tr2 column (a, d, e); the right filter predicate is not expected.
// NOTE(review): IsNull on the right-side columns would seem to hold only for left rows
// without a match; confirm that asserting it for the whole join output is intended.
test("propagating constraints in left-outer join") {
92+
val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
93+
val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
94+
verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), LeftOuter,
95+
Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints,
96+
Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
97+
IsNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get),
98+
IsNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get),
99+
IsNull(tr2.resolveQuoted("e", caseInsensitiveResolution).get)))
100+
}
101+
102+
// Right-outer join on tr1.a = tr2.a: mirror image of the left-outer case. The expected
// set is the right filter predicate plus IsNull on every tr1 column (a, b, c).
// NOTE(review): IsNull on the left-side columns would seem to hold only for right rows
// without a match; confirm that asserting it for the whole join output is intended.
test("propagating constraints in right-outer join") {
103+
val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
104+
val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
105+
verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), RightOuter,
106+
Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints,
107+
Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
108+
IsNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get),
109+
IsNull(tr1.resolveQuoted("b", caseInsensitiveResolution).get),
110+
IsNull(tr1.resolveQuoted("c", caseInsensitiveResolution).get)))
111+
}
112+
113+
// Full-outer join on tr1.a = tr2.a: neither child's filter predicate is expected; the
// expected set is IsNull on every column of both tr1 and tr2.
// NOTE(review): these IsNull expressions cannot all hold for a matched row; confirm
// that asserting them for the whole join output is intended.
test("propagating constraints in full-outer join") {
114+
val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
115+
val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
116+
verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), FullOuter,
117+
Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints,
118+
Set(IsNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get),
119+
IsNull(tr1.resolveQuoted("b", caseInsensitiveResolution).get),
120+
IsNull(tr1.resolveQuoted("c", caseInsensitiveResolution).get),
121+
IsNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get),
122+
IsNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get),
123+
IsNull(tr2.resolveQuoted("e", caseInsensitiveResolution).get)))
124+
}
65125
}

0 commit comments

Comments
 (0)