@@ -51,6 +51,14 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
5151 ! expressions.exists(! _.resolved) && childrenResolved && ! hasSpecialExpressions
5252 }
5353
/**
 * Distinct set reported by this projection.
 *
 * The child's `distinctSet` is forwarded only when the child has a non-empty output and
 * every child output attribute is still contained in this node's `outputSet`.
 * In every other case an empty [[AttributeSet]] is returned.
 */
override def distinctSet: AttributeSet = {
  val childAttrs = child.outputSet
  val keepsAllChildAttrs = childAttrs.nonEmpty && childAttrs.subsetOf(outputSet)
  if (keepsAllChildAttrs) child.distinctSet else AttributeSet.empty
}
61+
/** Child constraints combined with the constraints derived from aliases in `projectList`. */
override def validConstraints: Set[Expression] = {
  val inherited = child.constraints
  inherited.union(getAliasedConstraints(projectList))
}
5664}
@@ -107,6 +115,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
107115
/** Forwards the child's `maxRows` unchanged. */
override def maxRows: Option[Long] = {
  child.maxRows
}
109117
/** Forwards the child's `distinctSet` unchanged. */
override def distinctSet: AttributeSet = {
  child.distinctSet
}
119+
/**
 * Child constraints combined with each conjunct obtained by splitting `condition`
 * with `splitConjunctivePredicates`.
 */
override protected def validConstraints: Set[Expression] = {
  val inherited = child.constraints
  val conjuncts: Set[Expression] = splitConjunctivePredicates(condition).toSet
  inherited.union(conjuncts)
}
112122}
@@ -137,6 +147,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
137147 leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
138148 }
139149
/** Reports every output attribute of the left child as the distinct set. */
override def distinctSet: AttributeSet = {
  left.outputSet
}
151+
/** Union of `leftConstraints` and `rightConstraints`. */
override protected def validConstraints: Set[Expression] = {
  leftConstraints union rightConstraints
}
142154
@@ -168,6 +180,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
/**
 * Output is taken from the left side only; `right.output` is not used because rows
 * coming from the right side are excluded from the result.
 */
override def output: Seq[Attribute] = {
  left.output
}
170182
/** Reports every output attribute of the left child as the distinct set. */
override def distinctSet: AttributeSet = {
  left.outputSet
}
184+
/** Only `leftConstraints` are returned as this operator's constraints. */
override protected def validConstraints: Set[Expression] = {
  leftConstraints
}
172186
173187 override lazy val resolved : Boolean =
@@ -265,6 +279,9 @@ case class Join(
265279 }
266280 }
267281
/**
 * Distinct set of the join: the left child's `distinctSet` for a `LeftSemi` join,
 * and an empty [[AttributeSet]] for every other join type.
 */
override def distinctSet: AttributeSet = joinType match {
  case LeftSemi => left.distinctSet
  case _ => AttributeSet.empty
}
284+
268285 override protected def validConstraints : Set [Expression ] = {
269286 joinType match {
270287 case Inner if condition.isDefined =>
@@ -312,6 +329,7 @@ case class Join(
312329 */
313330case class BroadcastHint (child : LogicalPlan ) extends UnaryNode {
/** Forwards the child's output attributes unchanged. */
override def output: Seq[Attribute] = {
  child.output
}
/** Forwards the child's `distinctSet` unchanged. */
override def distinctSet: AttributeSet = {
  child.distinctSet
}
315333
316334 // We manually set statistics of BroadcastHint to smallest value to make sure
317335 // the plan wrapped by BroadcastHint will be considered to broadcast later.
@@ -367,6 +385,7 @@ case class Sort(
367385 child : LogicalPlan ) extends UnaryNode {
/** Forwards the child's output attributes unchanged. */
override def output: Seq[Attribute] = {
  child.output
}
/** Forwards the child's `maxRows` unchanged. */
override def maxRows: Option[Long] = {
  child.maxRows
}
/** Forwards the child's `distinctSet` unchanged. */
override def distinctSet: AttributeSet = {
  child.distinctSet
}
370389}
371390
372391/** Factory for constructing new `Range` nodes. */
@@ -422,6 +441,19 @@ case class Aggregate(
/** One output attribute per entry of `aggregateExpressions`, obtained via `toAttribute`. */
override def output: Seq[Attribute] = {
  aggregateExpressions.map(expr => expr.toAttribute)
}
/** Forwards the child's `maxRows` unchanged. */
override def maxRows: Option[Long] = {
  child.maxRows
}
424443
/**
 * Distinct set reported by this aggregate.
 *
 *  - When [[isForDistinct]] holds, the whole `outputSet` of this node is returned.
 *    `outputSet` is used rather than `AttributeSet(aggregateExpressions)` because the
 *    latter collects the attributes the expressions *reference*; for an aliased
 *    expression that would be the child's attributes instead of the attribute this
 *    node actually produces.
 *  - Otherwise, when the child has a non-empty output that is fully contained in this
 *    node's `outputSet`, the child's `distinctSet` is forwarded.
 *  - In every other case an empty [[AttributeSet]] is returned.
 */
override def distinctSet: AttributeSet = {
  if (isForDistinct) {
    outputSet
  } else if (child.outputSet.nonEmpty && child.outputSet.subsetOf(outputSet)) {
    child.distinctSet
  } else {
    AttributeSet.empty
  }
}
453+
/**
 * True when `groupingExpressions` is non-empty and equal to `aggregateExpressions`.
 */
def isForDistinct: Boolean = {
  if (groupingExpressions.isEmpty) false
  else groupingExpressions == aggregateExpressions
}
456+
/** Child constraints combined with those derived from aliases in `aggregateExpressions`. */
override def validConstraints: Set[Expression] = {
  val inherited = child.constraints
  inherited.union(getAliasedConstraints(aggregateExpressions))
}
427459
@@ -443,6 +475,8 @@ case class Window(
/** Child output followed by one attribute per entry of `windowExpressions`. */
override def output: Seq[Attribute] = {
  val inherited = child.output
  inherited ++ windowExpressions.map(expr => expr.toAttribute)
}
445477
/** Forwards the child's `distinctSet` unchanged. */
override def distinctSet: AttributeSet = {
  child.distinctSet
}
479+
/** Attribute set built from the attributes of `windowExpressions` only. */
def windowOutputSet: AttributeSet = {
  val windowAttrs = windowExpressions.map(_.toAttribute)
  AttributeSet(windowAttrs)
}
447481}
448482
@@ -585,6 +619,7 @@ object Limit {
585619
586620case class GlobalLimit (limitExpr : Expression , child : LogicalPlan ) extends UnaryNode {
/** Forwards the child's output attributes unchanged. */
override def output: Seq[Attribute] = {
  child.output
}
/** Forwards the child's `distinctSet` unchanged. */
override def distinctSet: AttributeSet = {
  child.distinctSet
}
588623 override def maxRows : Option [Long ] = {
589624 limitExpr match {
590625 case IntegerLiteral (limit) => Some (limit)
@@ -600,6 +635,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
600635
601636case class LocalLimit (limitExpr : Expression , child : LogicalPlan ) extends UnaryNode {
/** Forwards the child's output attributes unchanged. */
override def output: Seq[Attribute] = {
  child.output
}
/** Forwards the child's `distinctSet` unchanged. */
override def distinctSet: AttributeSet = {
  child.distinctSet
}
603639 override def maxRows : Option [Long ] = {
604640 limitExpr match {
605641 case IntegerLiteral (limit) => Some (limit)
@@ -615,6 +651,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
615651
/**
 * Unary node that exposes its child's output under the qualifier `alias`.
 *
 * @param alias qualifier applied to every output attribute
 * @param child plan whose output is re-qualified
 */
case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode {

  /** Child output with every attribute re-qualified by `alias`. */
  override def output: Seq[Attribute] = {
    val qualifier = Some(alias)
    child.output.map(attr => attr.withQualifier(qualifier))
  }

  /** Forwards the child's `distinctSet` unchanged. */
  override def distinctSet: AttributeSet = child.distinctSet
}
620657
@@ -638,6 +675,7 @@ case class Sample(
638675 val isTableSample : java.lang.Boolean = false ) extends UnaryNode {
639676
/** Forwards the child's output attributes unchanged. */
override def output: Seq[Attribute] = {
  child.output
}
/**
 * Distinct set reported by this sample.
 *
 * Sampling with replacement may emit the same input row more than once, so the child's
 * `distinctSet` is forwarded only when sampling without replacement; otherwise an empty
 * [[AttributeSet]] is returned.
 */
// NOTE(review): relies on the `withReplacement` constructor parameter of `Sample`,
// which is declared outside the lines visible here -- confirm the name.
override def distinctSet: AttributeSet = {
  if (withReplacement) AttributeSet.empty else child.distinctSet
}
641679
642680 override def statistics : Statistics = {
643681 val ratio = upperBound - lowerBound
@@ -658,6 +696,7 @@ case class Sample(
case class Distinct(child: LogicalPlan) extends UnaryNode {

  /** Forwards the child's output attributes unchanged. */
  override def output: Seq[Attribute] = child.output

  /** Forwards the child's `maxRows` unchanged. */
  override def maxRows: Option[Long] = child.maxRows

  /** Reports every output attribute of the child as the distinct set. */
  override def distinctSet: AttributeSet = child.outputSet
}
662701
663702/**
0 commit comments