Skip to content

Commit 7d95bc1

Browse files
committed
create a distinctSet for uniqueness constraint
1 parent bae2c86 commit 7d95bc1

File tree

3 files changed

+48
-25
lines changed

3 files changed

+48
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,31 +1213,10 @@ object EliminateDistinct extends Rule[LogicalPlan] {
12131213
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
12141214
// Eliminate the useless distinct.
12151215
// Distinct has been replaced by Aggregate in the rule ReplaceDistinctWithAggregate
1216-
case a @ Aggregate(grouping, aggs, child) if isDistinct(a) && isDistinct(child) => child
1217-
}
1218-
1219-
// propagate the distinct property from the child
1220-
@tailrec
1221-
private def isDistinct(plan: LogicalPlan): Boolean = plan match {
1222-
// Distinct(left) or Aggregate(left.output, left.output, _) always returns distinct results
1223-
case _: Distinct => true
1224-
case Aggregate(grouping, aggs, _) if grouping.nonEmpty && grouping == aggs => true
1225-
// BinaryNode:
1226-
case p @ Join(_, _, LeftSemi, _) => isDistinct(p.left)
1227-
case p: Intersect => isDistinct(p.left)
1228-
case p: Except => isDistinct(p.left)
1229-
// UnaryNode:
1230-
case p: Project if p.child.outputSet.subsetOf(p.outputSet) => isDistinct(p.child)
1231-
case p: Aggregate if p.child.outputSet.subsetOf(p.outputSet) => isDistinct(p.child)
1232-
case p: Filter => isDistinct(p.child)
1233-
case p: GlobalLimit => isDistinct(p.child)
1234-
case p: LocalLimit => isDistinct(p.child)
1235-
case p: Sort => isDistinct(p.child)
1236-
case p: BroadcastHint => isDistinct(p.child)
1237-
case p: Sample => isDistinct(p.child)
1238-
case p: SubqueryAlias => isDistinct(p.child)
1239-
// Others:
1240-
case o => false
1216+
case a @ Aggregate(grouping, aggs, child)
1217+
if child.distinctSet.nonEmpty && child.distinctSet.subsetOf(AttributeSet(aggs)) &&
1218+
a.isForDistinct =>
1219+
child
12411220
}
12421221
}
12431222

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
101101
*/
102102
protected def validConstraints: Set[Expression] = Set.empty
103103

104+
/**
105+
* The set of attributes whose combination can uniquely identify a row.
106+
*/
107+
def distinctSet: AttributeSet = AttributeSet.empty
108+
104109
/**
105110
* Returns the set of attributes that are output by this node.
106111
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
5151
!expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
5252
}
5353

54+
override def distinctSet: AttributeSet = {
55+
if (child.outputSet.nonEmpty && child.outputSet.subsetOf(outputSet)) {
56+
child.distinctSet
57+
} else {
58+
AttributeSet.empty
59+
}
60+
}
61+
5462
override def validConstraints: Set[Expression] =
5563
child.constraints.union(getAliasedConstraints(projectList))
5664
}
@@ -107,6 +115,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
107115

108116
override def maxRows: Option[Long] = child.maxRows
109117

118+
override def distinctSet: AttributeSet = child.distinctSet
119+
110120
override protected def validConstraints: Set[Expression] =
111121
child.constraints.union(splitConjunctivePredicates(condition).toSet)
112122
}
@@ -137,6 +147,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
137147
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
138148
}
139149

150+
override def distinctSet: AttributeSet = left.outputSet
151+
140152
override protected def validConstraints: Set[Expression] =
141153
leftConstraints.union(rightConstraints)
142154

@@ -168,6 +180,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
168180
/** We don't use right.output because those rows get excluded from the set. */
169181
override def output: Seq[Attribute] = left.output
170182

183+
override def distinctSet: AttributeSet = left.outputSet
184+
171185
override protected def validConstraints: Set[Expression] = leftConstraints
172186

173187
override lazy val resolved: Boolean =
@@ -265,6 +279,9 @@ case class Join(
265279
}
266280
}
267281

282+
override def distinctSet: AttributeSet =
283+
if (joinType == LeftSemi) left.distinctSet else AttributeSet.empty
284+
268285
override protected def validConstraints: Set[Expression] = {
269286
joinType match {
270287
case Inner if condition.isDefined =>
@@ -312,6 +329,7 @@ case class Join(
312329
*/
313330
case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
314331
override def output: Seq[Attribute] = child.output
332+
override def distinctSet: AttributeSet = child.distinctSet
315333

316334
// We manually set statistics of BroadcastHint to smallest value to make sure
317335
// the plan wrapped by BroadcastHint will be considered to broadcast later.
@@ -367,6 +385,7 @@ case class Sort(
367385
child: LogicalPlan) extends UnaryNode {
368386
override def output: Seq[Attribute] = child.output
369387
override def maxRows: Option[Long] = child.maxRows
388+
override def distinctSet: AttributeSet = child.distinctSet
370389
}
371390

372391
/** Factory for constructing new `Range` nodes. */
@@ -422,6 +441,19 @@ case class Aggregate(
422441
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
423442
override def maxRows: Option[Long] = child.maxRows
424443

444+
override def distinctSet: AttributeSet = {
445+
if (isForDistinct) {
446+
AttributeSet(aggregateExpressions)
447+
} else if (child.outputSet.nonEmpty && child.outputSet.subsetOf(outputSet)) {
448+
child.distinctSet
449+
} else {
450+
AttributeSet.empty
451+
}
452+
}
453+
454+
def isForDistinct: Boolean =
455+
groupingExpressions.nonEmpty && groupingExpressions == aggregateExpressions
456+
425457
override def validConstraints: Set[Expression] =
426458
child.constraints.union(getAliasedConstraints(aggregateExpressions))
427459

@@ -443,6 +475,8 @@ case class Window(
443475
override def output: Seq[Attribute] =
444476
child.output ++ windowExpressions.map(_.toAttribute)
445477

478+
override def distinctSet: AttributeSet = child.distinctSet
479+
446480
def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute))
447481
}
448482

@@ -585,6 +619,7 @@ object Limit {
585619

586620
case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
587621
override def output: Seq[Attribute] = child.output
622+
override def distinctSet: AttributeSet = child.distinctSet
588623
override def maxRows: Option[Long] = {
589624
limitExpr match {
590625
case IntegerLiteral(limit) => Some(limit)
@@ -600,6 +635,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
600635

601636
case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
602637
override def output: Seq[Attribute] = child.output
638+
override def distinctSet: AttributeSet = child.distinctSet
603639
override def maxRows: Option[Long] = {
604640
limitExpr match {
605641
case IntegerLiteral(limit) => Some(limit)
@@ -615,6 +651,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
615651

616652
case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode {
617653

654+
override def distinctSet: AttributeSet = child.distinctSet
618655
override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias)))
619656
}
620657

@@ -638,6 +675,7 @@ case class Sample(
638675
val isTableSample: java.lang.Boolean = false) extends UnaryNode {
639676

640677
override def output: Seq[Attribute] = child.output
678+
override def distinctSet: AttributeSet = child.distinctSet
641679

642680
override def statistics: Statistics = {
643681
val ratio = upperBound - lowerBound
@@ -658,6 +696,7 @@ case class Sample(
658696
case class Distinct(child: LogicalPlan) extends UnaryNode {
659697
override def maxRows: Option[Long] = child.maxRows
660698
override def output: Seq[Attribute] = child.output
699+
override def distinctSet: AttributeSet = child.outputSet
661700
}
662701

663702
/**

0 commit comments

Comments
 (0)