From 37352f45a06a70ede6eaf10e8b57d97c0bcbf731 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sun, 20 Mar 2016 15:01:34 -0700
Subject: [PATCH 1/5] eliminate Distinct

---
 .../spark/sql/catalyst/dsl/package.scala      |  2 +
 .../sql/catalyst/optimizer/Optimizer.scala    | 40 ++++++++++-
 .../optimizer/ReplaceOperatorSuite.scala      | 68 +++++++++++++++++--
 .../org/apache/spark/sql/DataFrameSuite.scala |  8 +++
 4 files changed, 111 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index dc5264e2660d..846919c75ec5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -282,6 +282,8 @@ package object dsl {
 
     def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan)
 
+    def distinct(): LogicalPlan = Distinct(logicalPlan)
+
     def generate(
         generator: Generator,
         join: Boolean = false,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c419b5fd2204..088fb2ca0dd3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -89,7 +89,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
       PruneFilters,
       SimplifyCasts,
       SimplifyCaseConversionExpressions,
-      EliminateSerialization) ::
+      EliminateSerialization,
+      EliminateDistinct) ::
     Batch("Decimal Optimizations", FixedPoint(100),
       DecimalAggregates) ::
     Batch("LocalRelation", FixedPoint(100),
@@ -1193,6 +1194,41 @@ object RemoveDispensableExpressions extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Removes [[Distinct]] operators that are redundant because the child already returns distinct rows.
+ */
+object EliminateDistinct extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    // Eliminate the useless distinct.
+    // Distinct has been replaced by Aggregate in the rule ReplaceDistinctWithAggregate
+    case a @ Aggregate(grouping, aggs, child) if isDistinct(a) && isDistinct(child) => child
+  }
+
+  // propagate the distinct property from the child
+  @tailrec
+  private def isDistinct(plan: LogicalPlan): Boolean = plan match {
+    // Distinct(left) or Aggregate(left.output, left.output, _) always returns distinct results
+    case _: Distinct => true
+    case Aggregate(grouping, aggs, _) if grouping == aggs => true
+    // BinaryNode:
+    case p @ Join(_, _, LeftSemi, _) => isDistinct(p.left)
+    case p: Intersect => isDistinct(p.left)
+    case p: Except => isDistinct(p.left)
+    // UnaryNode:
+    case p: Project if p.child.outputSet.subsetOf(p.outputSet) => isDistinct(p.child)
+    case p: Aggregate if p.child.outputSet.subsetOf(p.outputSet) => isDistinct(p.child)
+    case p: Filter => isDistinct(p.child)
+    case p: GlobalLimit => isDistinct(p.child)
+    case p: LocalLimit => isDistinct(p.child)
+    case p: Sort => isDistinct(p.child)
+    case p: BroadcastHint => isDistinct(p.child)
+    case p: Sample => isDistinct(p.child)
+    case p: SubqueryAlias => isDistinct(p.child)
+    // Others:
+    case o => false
+  }
+}
+
 /**
  * Combines two adjacent [[Limit]] operators into one, merging the
  * expressions into one single expression.
@@ -1291,7 +1327,7 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case Intersect(left, right) =>
       assert(left.output.size == right.output.size)
-      val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
+      val joinCond = left.output.zip(right.output).map(EqualNullSafe.tupled)
       Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And)))
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index f8ae5d9be208..7fe5575ab236 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -28,8 +28,9 @@ class ReplaceOperatorSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("Replace Operators", FixedPoint(100),
+        ReplaceIntersectWithSemiJoin,
         ReplaceDistinctWithAggregate,
-        ReplaceIntersectWithSemiJoin) :: Nil
+        EliminateDistinct) :: Nil
   }
 
   test("replace Intersect with Left-semi Join") {
@@ -40,8 +41,65 @@ class ReplaceOperatorSuite extends PlanTest {
     val optimized = Optimize.execute(query.analyze)
 
     val correctAnswer =
-      Aggregate(table1.output, table1.output,
-        Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze
+      table1.join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd)).groupBy('a, 'b)('a, 'b).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("replace Intersect with Left-semi Join whose left is Distinct") {
+    val table1 = LocalRelation('a.int, 'b.int)
+    val table2 = LocalRelation('c.int, 'd.int)
+
+    val query = table1.distinct().intersect(table2)
+    val optimized = Optimize.execute(query.analyze)
+
+    val correctAnswer =
+      table1.groupBy('a, 'b)('a, 'b).join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd)).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("consecutive Intersects whose children contain Distinct") {
+    val table1 = LocalRelation('a.int, 'b.int)
+    val table2 = LocalRelation('c.int, 'd.int)
+    val table3 = LocalRelation('e.int, 'f.int)
+
+    // DISTINCT (already rewritten to AGGREGATE by ReplaceDistinctWithAggregate) is the direct child
+    val query1 = table1.distinct().intersect(table2).intersect(table3)
+    val correctAnswer1 =
+      table1.groupBy('a, 'b)('a, 'b)
+        .join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))
+        .join(table3, LeftSemi, Option('a <=> 'e && 'b <=> 'f)).analyze
+    comparePlans(Optimize.execute(query1.analyze), correctAnswer1)
+  }
+
+  test("replace Intersect with Left-semi Join whose left is inferred to have distinct values") {
+    val table1 = LocalRelation('a.int)
+    val table2 = LocalRelation('c.int, 'd.int)
+    val table3 = LocalRelation('e.int, 'f.int)
+
+    // DISTINCT is inferred from a node further down the plan
+    val query2 = table1.distinct()
+      .where('a > 3).limit(5)
+      .select('a.attr, ('a + 1).as("b")).orderBy('a.asc, 'b.desc)
+      .intersect(table2).intersect(table3)
+    val correctAnswer2 =
+      table1.groupBy('a)('a).where('a > 3).limit(5)
+        .select('a.attr, ('a + 1).as("b")).orderBy('a.asc, 'b.desc)
+        .join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))
+        .join(table3, LeftSemi, Option('a <=> 'e && 'b <=> 'f)).analyze
+    comparePlans(Optimize.execute(query2.analyze), correctAnswer2)
+  }
+
+  test("replace Intersect with Left-semi Join whose left is a distinct Aggregate") {
+    val table1 = LocalRelation('a.int, 'b.int)
+    val table2 = LocalRelation('c.int, 'd.int)
+
+    val query = table1.groupBy('a, 'b)('a, 'b).intersect(table2)
+    val optimized = Optimize.execute(query.analyze)
+
+    val correctAnswer =
+      table1.groupBy('a, 'b)('a, 'b).join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd)).analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -49,10 +107,10 @@ class ReplaceOperatorSuite extends PlanTest {
   test("replace Distinct with Aggregate") {
     val input = LocalRelation('a.int, 'b.int)
 
-    val query = Distinct(input)
+    val query = input.distinct()
     val optimized = Optimize.execute(query.analyze)
 
-    val correctAnswer = Aggregate(input.output, input.output, input)
+    val correctAnswer = input.groupBy('a, 'b)('a, 'b).analyze
 
     comparePlans(optimized, correctAnswer)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 199e138abfdc..803d085dc0b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -389,6 +389,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row("id1", 1) ::
       Row("id", 1) ::
       Row("id1", 2) :: Nil)
+
+    checkAnswer(
+      df.distinct().intersect(df),
+      Row("id1", 1) ::
+      Row("id", 1) ::
+      Row("id1", 2) :: Nil)
+
+    df.distinct().intersect(df).intersect(df).explain(true)
   }
 
   test("intersect - nullability") {
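
Patch 1 infers distinctness bottom-up: Distinct and distinct-like Aggregates produce distinct rows, and duplicate-preserving operators (Filter, Limit, Sort, ...) pass the property along. A minimal standalone sketch of that propagation, using a hypothetical toy plan ADT rather than Spark's LogicalPlan:

// Toy model of the isDistinct propagation in EliminateDistinct.
// ToyPlan and its cases are hypothetical stand-ins, not Spark classes.
sealed trait ToyPlan
case class Relation(name: String) extends ToyPlan
case class Distinct(child: ToyPlan) extends ToyPlan
case class Filter(child: ToyPlan) extends ToyPlan
case class Limit(n: Int, child: ToyPlan) extends ToyPlan

object DistinctCheck {
  // Distinctness survives filtering and limiting, mirroring the
  // Filter/GlobalLimit/LocalLimit cases in the rule.
  @scala.annotation.tailrec
  def isDistinct(plan: ToyPlan): Boolean = plan match {
    case Distinct(_)     => true
    case Filter(child)   => isDistinct(child)
    case Limit(_, child) => isDistinct(child)
    case _               => false
  }

  // A Distinct above an already-distinct subtree can be dropped.
  def eliminate(plan: ToyPlan): ToyPlan = plan match {
    case Distinct(child) if isDistinct(child) => eliminate(child)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val plan = Distinct(Filter(Distinct(Relation("t"))))
    // Prints Filter(Distinct(Relation(t))): the redundant outer Distinct is gone.
    println(eliminate(plan))
  }
}

The real rule matches Aggregate rather than Distinct because ReplaceDistinctWithAggregate has already run by the time EliminateDistinct fires.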
From 96d9d4e0310f2edde0463e458db75962253f2771 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sun, 20 Mar 2016 15:22:37 -0700
Subject: [PATCH 2/5] code cleanup.

---
 .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 --
 1 file changed, 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 803d085dc0b8..de38a3e9e779 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -395,8 +395,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row("id1", 1) ::
       Row("id", 1) ::
       Row("id1", 2) :: Nil)
-
-    df.distinct().intersect(df).intersect(df).explain(true)
   }
 
   test("intersect - nullability") {
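
Patch 3 below pins down the consecutive-Intersect case: a chain of Intersects needs only one duplicate-eliminating Aggregate. A small self-contained illustration of why, using plain Scala collections in place of Spark plans (semiJoin is a hypothetical helper):

// Once the left side is deduplicated, a left-semi-join style intersect
// keeps it deduplicated, so chained intersects need only one distinct.
object IntersectSketch {
  // Left-semi join on equality: keep left rows that have a match in right.
  def semiJoin(left: Seq[Int], right: Seq[Int]): Seq[Int] = {
    val rightSet = right.toSet
    left.filter(rightSet.contains)
  }

  def main(args: Array[String]): Unit = {
    val t1 = Seq(1, 1, 2, 3)
    val t2 = Seq(1, 2, 2)
    val t3 = Seq(2, 3)

    // t1 INTERSECT t2 INTERSECT t3: deduplicate once, then semi-join twice.
    val result = semiJoin(semiJoin(t1.distinct, t2), t3)
    println(result)  // List(2): still duplicate-free without a second Distinct
  }
}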
From dddc78be7a2cc05f1a832994ab94cc7d9b59e1d3 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sun, 20 Mar 2016 15:38:19 -0700
Subject: [PATCH 3/5] added one more test case.

---
 .../catalyst/optimizer/ReplaceOperatorSuite.scala  | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index 7fe5575ab236..c9b8fb6f9d30 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -73,6 +73,20 @@ class ReplaceOperatorSuite extends PlanTest {
     comparePlans(Optimize.execute(query1.analyze), correctAnswer1)
   }
 
+  test("consecutive Intersects whose children do not contain Distinct") {
+    val table1 = LocalRelation('a.int, 'b.int)
+    val table2 = LocalRelation('c.int, 'd.int)
+    val table3 = LocalRelation('e.int, 'f.int)
+
+    // Only one Distinct is needed for consecutive Intersects, even if no child has Distinct.
+    val query1 = table1.intersect(table2).intersect(table3)
+    val correctAnswer1 =
+      table1
+        .join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd)).groupBy('a, 'b)('a, 'b)
+        .join(table3, LeftSemi, Option('a <=> 'e && 'b <=> 'f)).analyze
+    comparePlans(Optimize.execute(query1.analyze), correctAnswer1)
+  }
+
   test("replace Intersect with Left-semi Join whose left is inferred to have distinct values") {
     val table1 = LocalRelation('a.int)
     val table2 = LocalRelation('c.int, 'd.int)
     val table3 = LocalRelation('e.int, 'f.int)

From bae2c8642d6bb46c6960b99c0e13ab7b807e038f Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Tue, 22 Mar 2016 00:17:16 -0700
Subject: [PATCH 4/5] fix R test case failure.

---
 .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 33c7c02cf95c..e71e4f68b345 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1221,7 +1221,7 @@ object EliminateDistinct extends Rule[LogicalPlan] {
   private def isDistinct(plan: LogicalPlan): Boolean = plan match {
     // Distinct(left) or Aggregate(left.output, left.output, _) always returns distinct results
     case _: Distinct => true
-    case Aggregate(grouping, aggs, _) if grouping == aggs => true
+    case Aggregate(grouping, aggs, _) if grouping.nonEmpty && grouping == aggs => true
     // BinaryNode:
     case p @ Join(_, _, LeftSemi, _) => isDistinct(p.left)
     case p: Intersect => isDistinct(p.left)
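
The one-line guard in patch 4 matters because an Aggregate whose grouping list is empty is a global aggregation: it always returns a single row and does not deduplicate its input, so the bare grouping == aggs comparison would misclassify it when both lists are empty. A standalone sketch of the corrected check (a hypothetical helper over column names, not the Spark rule itself):

// Empty grouping means "one row over the whole input", not "remove
// duplicates", hence the grouping.nonEmpty guard.
object GroupingGuard {
  def actsAsDistinct(grouping: Seq[String], aggs: Seq[String]): Boolean =
    grouping.nonEmpty && grouping == aggs

  def main(args: Array[String]): Unit = {
    println(actsAsDistinct(Seq("a", "b"), Seq("a", "b"))) // true: GROUP BY a, b selecting a, b
    println(actsAsDistinct(Seq.empty, Seq.empty))         // false: global aggregate
  }
}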
From 7d95bc17c2523fa25bcac59ed03886e8b0cb8c40 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Wed, 23 Mar 2016 21:56:59 -0700
Subject: [PATCH 5/5] create a distinctSet for uniqueness constraint

---
 .../sql/catalyst/optimizer/Optimizer.scala    | 29 ++------------
 .../spark/sql/catalyst/plans/QueryPlan.scala  |  5 +++
 .../plans/logical/basicOperators.scala        | 39 +++++++++++++++++++
 3 files changed, 48 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e71e4f68b345..5e02411d6b55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1213,31 +1213,10 @@ object EliminateDistinct extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // Eliminate the useless distinct.
     // Distinct has been replaced by Aggregate in the rule ReplaceDistinctWithAggregate
-    case a @ Aggregate(grouping, aggs, child) if isDistinct(a) && isDistinct(child) => child
-  }
-
-  // propagate the distinct property from the child
-  @tailrec
-  private def isDistinct(plan: LogicalPlan): Boolean = plan match {
-    // Distinct(left) or Aggregate(left.output, left.output, _) always returns distinct results
-    case _: Distinct => true
-    case Aggregate(grouping, aggs, _) if grouping.nonEmpty && grouping == aggs => true
-    // BinaryNode:
-    case p @ Join(_, _, LeftSemi, _) => isDistinct(p.left)
-    case p: Intersect => isDistinct(p.left)
-    case p: Except => isDistinct(p.left)
-    // UnaryNode:
-    case p: Project if p.child.outputSet.subsetOf(p.outputSet) => isDistinct(p.child)
-    case p: Aggregate if p.child.outputSet.subsetOf(p.outputSet) => isDistinct(p.child)
-    case p: Filter => isDistinct(p.child)
-    case p: GlobalLimit => isDistinct(p.child)
-    case p: LocalLimit => isDistinct(p.child)
-    case p: Sort => isDistinct(p.child)
-    case p: BroadcastHint => isDistinct(p.child)
-    case p: Sample => isDistinct(p.child)
-    case p: SubqueryAlias => isDistinct(p.child)
-    // Others:
-    case o => false
+    case a @ Aggregate(grouping, aggs, child)
+        if child.distinctSet.nonEmpty && child.distinctSet.subsetOf(AttributeSet(aggs)) &&
+          a.isForDistinct =>
+      child
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index e9bfa09b7dff..4c6e98b5eafe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -101,6 +101,11 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    */
   protected def validConstraints: Set[Expression] = Set.empty
 
+  /**
+   * The set of attributes whose combination can uniquely identify a row.
+   */
+  def distinctSet: AttributeSet = AttributeSet.empty
+
   /**
    * Returns the set of attributes that are output by this node.
   */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 09c200fa839c..548bcb2ac0ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -51,6 +51,14 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
     !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
   }
 
+  override def distinctSet: AttributeSet = {
+    if (child.outputSet.nonEmpty && child.outputSet.subsetOf(outputSet)) {
+      child.distinctSet
+    } else {
+      AttributeSet.empty
+    }
+  }
+
   override def validConstraints: Set[Expression] =
     child.constraints.union(getAliasedConstraints(projectList))
 }
@@ -107,6 +115,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
 
   override def maxRows: Option[Long] = child.maxRows
 
+  override def distinctSet: AttributeSet = child.distinctSet
+
   override protected def validConstraints: Set[Expression] =
     child.constraints.union(splitConjunctivePredicates(condition).toSet)
 }
@@ -137,6 +147,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
       leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
   }
 
+  override def distinctSet: AttributeSet = left.outputSet
+
   override protected def validConstraints: Set[Expression] =
     leftConstraints.union(rightConstraints)
 
@@ -168,6 +180,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
   /** We don't use right.output because those rows get excluded from the set. */
   override def output: Seq[Attribute] = left.output
 
+  override def distinctSet: AttributeSet = left.outputSet
+
   override protected def validConstraints: Set[Expression] = leftConstraints
 
   override lazy val resolved: Boolean =
@@ -265,6 +279,9 @@ case class Join(
     }
   }
 
+  override def distinctSet: AttributeSet =
+    if (joinType == LeftSemi) left.distinctSet else AttributeSet.empty
+
   override protected def validConstraints: Set[Expression] = {
     joinType match {
       case Inner if condition.isDefined =>
@@ -312,6 +329,7 @@ case class Join(
  */
 case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
+  override def distinctSet: AttributeSet = child.distinctSet
 
   // We manually set statistics of BroadcastHint to smallest value to make sure
   // the plan wrapped by BroadcastHint will be considered to broadcast later.
@@ -367,6 +385,7 @@ case class Sort(
     child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
   override def maxRows: Option[Long] = child.maxRows
+  override def distinctSet: AttributeSet = child.distinctSet
 }
 
 /** Factory for constructing new `Range` nodes. */
@@ -422,6 +441,19 @@ case class Aggregate(
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
   override def maxRows: Option[Long] = child.maxRows
 
+  override def distinctSet: AttributeSet = {
+    if (isForDistinct) {
+      AttributeSet(aggregateExpressions)
+    } else if (child.outputSet.nonEmpty && child.outputSet.subsetOf(outputSet)) {
+      child.distinctSet
+    } else {
+      AttributeSet.empty
+    }
+  }
+
+  def isForDistinct: Boolean =
+    groupingExpressions.nonEmpty && groupingExpressions == aggregateExpressions
+
   override def validConstraints: Set[Expression] =
     child.constraints.union(getAliasedConstraints(aggregateExpressions))
 
@@ -443,6 +475,8 @@ case class Window(
   override def output: Seq[Attribute] =
     child.output ++ windowExpressions.map(_.toAttribute)
 
+  override def distinctSet: AttributeSet = child.distinctSet
+
   def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute))
 }
 
@@ -585,6 +619,7 @@ object Limit {
 
 case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
+  override def distinctSet: AttributeSet = child.distinctSet
   override def maxRows: Option[Long] = {
     limitExpr match {
       case IntegerLiteral(limit) => Some(limit)
@@ -600,6 +635,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
 
 case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
+  override def distinctSet: AttributeSet = child.distinctSet
   override def maxRows: Option[Long] = {
     limitExpr match {
       case IntegerLiteral(limit) => Some(limit)
@@ -615,6 +651,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
 
 case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode {
+  override def distinctSet: AttributeSet = child.distinctSet
 
   override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias)))
 }
@@ -638,6 +675,7 @@ case class Sample(
     val isTableSample: java.lang.Boolean = false) extends UnaryNode {
 
   override def output: Seq[Attribute] = child.output
+  override def distinctSet: AttributeSet = child.distinctSet
 
   override def statistics: Statistics = {
     val ratio = upperBound - lowerBound
@@ -658,6 +696,7 @@ case class Sample(
 case class Distinct(child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  override def distinctSet: AttributeSet = child.outputSet
 }
 
 /**
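
Patch 5 replaces the rule-local isDistinct traversal with a distinctSet property on every plan node: each operator reports the attributes that uniquely identify its rows (empty when nothing is known), and most unary operators simply forward the child's set. A minimal standalone model of that design, with hypothetical toy classes in place of the Spark ones:

// Each node exposes the columns that uniquely identify its rows; the
// default is "unknown" (empty set), mirroring QueryPlan.distinctSet.
sealed trait ToyNode {
  def distinctSet: Set[String] = Set.empty
}
case class Table(cols: Set[String]) extends ToyNode
case class GroupBy(keys: Set[String], child: ToyNode) extends ToyNode {
  // The grouping keys uniquely identify each output row.
  override def distinctSet: Set[String] = keys
}
case class Where(child: ToyNode) extends ToyNode {
  // Filtering never introduces duplicates, so the property is preserved.
  override def distinctSet: Set[String] = child.distinctSet
}

object DistinctSetDemo {
  def main(args: Array[String]): Unit = {
    val plan = Where(GroupBy(Set("a", "b"), Table(Set("a", "b", "c"))))
    // Non-empty, so a distinct-style Aggregate over (a, b) above this
    // plan would be eliminated by the rewritten rule.
    println(plan.distinctSet)  // Set(a, b)
  }
}

Pushing the property onto the nodes lets other rules reuse it and keeps EliminateDistinct itself to a single pattern match.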