@@ -29,14 +29,19 @@ class ConstraintPropagationSuite extends SparkFunSuite {
2929 private def resolveColumn (tr : LocalRelation , columnName : String ): Expression =
3030 tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get
3131
32+ private def verifyConstraints (a : Set [Expression ], b : Set [Expression ]): Unit = {
33+ assert(a.forall(i => b.map(_.semanticEquals(i)).reduce(_ || _)))
34+ assert(b.forall(i => a.map(_.semanticEquals(i)).reduce(_ || _)))
35+ }
36+
3237 test(" propagating constraints in filter/project" ) {
3338 val tr = LocalRelation (' a .int, ' b .string, ' c .int)
3439 assert(tr.analyze.constraints.isEmpty)
3540 assert(tr.select(' a .attr).analyze.constraints.isEmpty)
36- assert(tr.where(' a .attr > 10 ).analyze.constraints == Set (resolveColumn(tr, " a" ) > 10 ))
3741 assert(tr.where(' a .attr > 10 ).select(' c .attr, ' b .attr).analyze.constraints.isEmpty)
38- assert(tr.where(' a .attr > 10 ).select(' c .attr, ' a .attr).where(' c .attr < 100 )
39- .analyze.constraints == Set (resolveColumn(tr, " a" ) > 10 , resolveColumn(tr, " c" ) < 100 ))
42+ verifyConstraints(tr.where(' a .attr > 10 ).analyze.constraints, Set (resolveColumn(tr, " a" ) > 10 ))
43+ verifyConstraints(tr.where(' a .attr > 10 ).select(' c .attr, ' a .attr).where(' c .attr < 100 )
44+ .analyze.constraints, Set (resolveColumn(tr, " a" ) > 10 , resolveColumn(tr, " c" ) < 100 ))
4045 }
4146
4247 test(" propagating constraints in union" ) {
@@ -45,21 +50,76 @@ class ConstraintPropagationSuite extends SparkFunSuite {
4550 val tr3 = LocalRelation (' g .int, ' h .int, ' i .int)
4651 assert(tr1.where(' a .attr > 10 ).unionAll(tr2.where(' e .attr > 10 )
4752 .unionAll(tr3.where(' i .attr > 10 ))).analyze.constraints.isEmpty)
48- assert (tr1.where(' a .attr > 10 ).unionAll(tr2.where(' d .attr > 10 )
49- .unionAll(tr3.where(' g .attr > 10 ))).analyze.constraints == Set (resolveColumn(tr1, " a" ) > 10 ))
53+ verifyConstraints (tr1.where(' a .attr > 10 ).unionAll(tr2.where(' d .attr > 10 )
54+ .unionAll(tr3.where(' g .attr > 10 ))).analyze.constraints, Set (resolveColumn(tr1, " a" ) > 10 ))
5055 }
5156
5257 test(" propagating constraints in intersect" ) {
5358 val tr1 = LocalRelation (' a .int, ' b .int, ' c .int)
5459 val tr2 = LocalRelation (' a .int, ' b .int, ' c .int)
55- assert (tr1.where(' a .attr > 10 ).intersect(tr2.where(' b .attr < 100 )).analyze.constraints ==
56- Set (resolveColumn(tr1, " a" ) > 10 , resolveColumn(tr1, " b" ) < 100 ))
60+ verifyConstraints (tr1.where(' a .attr > 10 ).intersect(tr2.where(' b .attr < 100 ))
61+ .analyze.constraints, Set (resolveColumn(tr1, " a" ) > 10 , resolveColumn(tr1, " b" ) < 100 ))
5762 }
5863
5964 test(" propagating constraints in except" ) {
6065 val tr1 = LocalRelation (' a .int, ' b .int, ' c .int)
6166 val tr2 = LocalRelation (' a .int, ' b .int, ' c .int)
62- assert (tr1.where(' a .attr > 10 ).except(tr2.where(' b .attr < 100 )).analyze.constraints ==
67+ verifyConstraints (tr1.where(' a .attr > 10 ).except(tr2.where(' b .attr < 100 )).analyze.constraints,
6368 Set (resolveColumn(tr1, " a" ) > 10 ))
6469 }
70+
71+ test(" propagating constraints in inner join" ) {
72+ val tr1 = LocalRelation (' a .int, ' b .int, ' c .int).subquery(' tr1 )
73+ val tr2 = LocalRelation (' a .int, ' d .int, ' e .int).subquery(' tr2 )
74+ verifyConstraints(tr1.where(' a .attr > 10 ).join(tr2.where(' d .attr < 100 ), Inner ,
75+ Some (" tr1.a" .attr === " tr2.a" .attr)).analyze.constraints,
76+ Set (tr1.resolveQuoted(" a" , caseInsensitiveResolution).get > 10 ,
77+ tr2.resolveQuoted(" d" , caseInsensitiveResolution).get < 100 ,
78+ IsNotNull (tr2.resolveQuoted(" a" , caseInsensitiveResolution).get),
79+ IsNotNull (tr1.resolveQuoted(" a" , caseInsensitiveResolution).get)))
80+ }
81+
82+ test(" propagating constraints in left-semi join" ) {
83+ val tr1 = LocalRelation (' a .int, ' b .int, ' c .int).subquery(' tr1 )
84+ val tr2 = LocalRelation (' a .int, ' d .int, ' e .int).subquery(' tr2 )
85+ verifyConstraints(tr1.where(' a .attr > 10 ).join(tr2.where(' d .attr < 100 ), LeftSemi ,
86+ Some (" tr1.a" .attr === " tr2.a" .attr)).analyze.constraints,
87+ Set (tr1.resolveQuoted(" a" , caseInsensitiveResolution).get > 10 ,
88+ IsNotNull (tr1.resolveQuoted(" a" , caseInsensitiveResolution).get)))
89+ }
90+
91+ test(" propagating constraints in left-outer join" ) {
92+ val tr1 = LocalRelation (' a .int, ' b .int, ' c .int).subquery(' tr1 )
93+ val tr2 = LocalRelation (' a .int, ' d .int, ' e .int).subquery(' tr2 )
94+ verifyConstraints(tr1.where(' a .attr > 10 ).join(tr2.where(' d .attr < 100 ), LeftOuter ,
95+ Some (" tr1.a" .attr === " tr2.a" .attr)).analyze.constraints,
96+ Set (tr1.resolveQuoted(" a" , caseInsensitiveResolution).get > 10 ,
97+ IsNull (tr2.resolveQuoted(" a" , caseInsensitiveResolution).get),
98+ IsNull (tr2.resolveQuoted(" d" , caseInsensitiveResolution).get),
99+ IsNull (tr2.resolveQuoted(" e" , caseInsensitiveResolution).get)))
100+ }
101+
102+ test(" propagating constraints in right-outer join" ) {
103+ val tr1 = LocalRelation (' a .int, ' b .int, ' c .int).subquery(' tr1 )
104+ val tr2 = LocalRelation (' a .int, ' d .int, ' e .int).subquery(' tr2 )
105+ verifyConstraints(tr1.where(' a .attr > 10 ).join(tr2.where(' d .attr < 100 ), RightOuter ,
106+ Some (" tr1.a" .attr === " tr2.a" .attr)).analyze.constraints,
107+ Set (tr2.resolveQuoted(" d" , caseInsensitiveResolution).get < 100 ,
108+ IsNull (tr1.resolveQuoted(" a" , caseInsensitiveResolution).get),
109+ IsNull (tr1.resolveQuoted(" b" , caseInsensitiveResolution).get),
110+ IsNull (tr1.resolveQuoted(" c" , caseInsensitiveResolution).get)))
111+ }
112+
113+ test(" propagating constraints in full-outer join" ) {
114+ val tr1 = LocalRelation (' a .int, ' b .int, ' c .int).subquery(' tr1 )
115+ val tr2 = LocalRelation (' a .int, ' d .int, ' e .int).subquery(' tr2 )
116+ verifyConstraints(tr1.where(' a .attr > 10 ).join(tr2.where(' d .attr < 100 ), FullOuter ,
117+ Some (" tr1.a" .attr === " tr2.a" .attr)).analyze.constraints,
118+ Set (IsNull (tr1.resolveQuoted(" a" , caseInsensitiveResolution).get),
119+ IsNull (tr1.resolveQuoted(" b" , caseInsensitiveResolution).get),
120+ IsNull (tr1.resolveQuoted(" c" , caseInsensitiveResolution).get),
121+ IsNull (tr2.resolveQuoted(" a" , caseInsensitiveResolution).get),
122+ IsNull (tr2.resolveQuoted(" d" , caseInsensitiveResolution).get),
123+ IsNull (tr2.resolveQuoted(" e" , caseInsensitiveResolution).get)))
124+ }
65125}
0 commit comments