@@ -87,7 +87,8 @@ case class Generate(
8787 }
8888}
8989
90- case class Filter (condition : Expression , child : LogicalPlan ) extends UnaryNode {
90+ case class Filter (condition : Expression , child : LogicalPlan )
91+ extends UnaryNode with PredicateHelper {
9192 override def output : Seq [Attribute ] = child.output
9293
9394 override protected def validConstraints : Set [Expression ] = {
@@ -179,29 +180,35 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
179180 Statistics (sizeInBytes = sizeInBytes)
180181 }
181182
182- def rewriteConstraints (
183- planA : LogicalPlan ,
184- planB : LogicalPlan ,
183+ /**
184+ * Maps the constraints containing a given (original) sequence of attributes to those with a
185+ * given (reference) sequence of attributes. Given the nature of union, we expect that the
186+ * mapping between the original and reference sequences are symmetric.
187+ */
188+ private def rewriteConstraints (
189+ reference : Seq [Attribute ],
190+ original : Seq [Attribute ],
185191 constraints : Set [Expression ]): Set [Expression ] = {
186- require(planA.output. size == planB.output .size)
187- val attributeRewrites = AttributeMap (planB.output. zip(planA.output ))
192+ require(reference. size == original .size)
193+ val attributeRewrites = AttributeMap (original. zip(reference ))
188194 constraints.map(_ transform {
189195 case a : Attribute => attributeRewrites(a)
190196 })
191197 }
192198
193199 override protected def validConstraints : Set [Expression ] = {
194200 children
195- .map(child => rewriteConstraints(children.head, child, child.constraints))
201+ .map(child => rewriteConstraints(children.head.output , child.output , child.constraints))
196202 .reduce(_ intersect _)
197203 }
198204}
199205
200206case class Join (
201- left : LogicalPlan ,
202- right : LogicalPlan ,
203- joinType : JoinType ,
204- condition : Option [Expression ]) extends BinaryNode {
207+ left : LogicalPlan ,
208+ right : LogicalPlan ,
209+ joinType : JoinType ,
210+ condition : Option [Expression ])
211+ extends BinaryNode with PredicateHelper {
205212
206213 override def output : Seq [Attribute ] = {
207214 joinType match {
@@ -226,12 +233,11 @@ case class Join(
226233 .union(splitConjunctivePredicates(condition.get).toSet)
227234 case LeftSemi if condition.isDefined =>
228235 left.constraints
229- .union(right.constraints)
230236 .union(splitConjunctivePredicates(condition.get).toSet)
231237 case Inner =>
232238 left.constraints.union(right.constraints)
233239 case LeftSemi =>
234- left.constraints.union(right.constraints)
240+ left.constraints
235241 case LeftOuter =>
236242 left.constraints
237243 case RightOuter =>
0 commit comments