diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index dc5264e2660d..846919c75ec5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -282,6 +282,8 @@ package object dsl { def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) + def distinct(): LogicalPlan = Distinct(logicalPlan) + def generate( generator: Generator, join: Boolean = false, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 41e8dc0f4674..5e02411d6b55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -90,7 +90,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, - EliminateSerialization) :: + EliminateSerialization, + EliminateDistinct) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -1205,6 +1206,20 @@ object RemoveDispensableExpressions extends Rule[LogicalPlan] { } } +/** + * Removes [[Distinct]] operators that are unnecessary. + */ +object EliminateDistinct extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Eliminate the useless distinct.
+ // Distinct has been replaced by Aggregate in the rule ReplaceDistinctWithAggregate + case a @ Aggregate(grouping, aggs, child) + if child.distinctSet.nonEmpty && child.distinctSet.subsetOf(AttributeSet(aggs)) && + a.isForDistinct => + child + } +} + /** * Combines two adjacent [[Limit]] operators into one, merging the * expressions into one single expression. @@ -1303,7 +1318,7 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Intersect(left, right) => assert(left.output.size == right.output.size) - val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } + val joinCond = left.output.zip(right.output).map(EqualNullSafe.tupled) Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index e9bfa09b7dff..4c6e98b5eafe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -101,6 +101,11 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ protected def validConstraints: Set[Expression] = Set.empty + /** + * The set of attributes whose combination can uniquely identify a row. + */ + def distinctSet: AttributeSet = AttributeSet.empty + /** * Returns the set of attributes that are output by this node. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 09c200fa839c..548bcb2ac0ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -51,6 +51,14 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } + override def distinctSet: AttributeSet = { + if (child.outputSet.nonEmpty && child.outputSet.subsetOf(outputSet)) { + child.distinctSet + } else { + AttributeSet.empty + } + } + override def validConstraints: Set[Expression] = child.constraints.union(getAliasedConstraints(projectList)) } @@ -107,6 +115,8 @@ case class Filter(condition: Expression, child: LogicalPlan) override def maxRows: Option[Long] = child.maxRows + override def distinctSet: AttributeSet = child.distinctSet + override protected def validConstraints: Set[Expression] = child.constraints.union(splitConjunctivePredicates(condition).toSet) } @@ -137,6 +147,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } + override def distinctSet: AttributeSet = left.outputSet + override protected def validConstraints: Set[Expression] = leftConstraints.union(rightConstraints) @@ -168,6 +180,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le /** We don't use right.output because those rows get excluded from the set. 
*/ override def output: Seq[Attribute] = left.output + override def distinctSet: AttributeSet = left.outputSet + override protected def validConstraints: Set[Expression] = leftConstraints override lazy val resolved: Boolean = @@ -265,6 +279,9 @@ case class Join( } } + override def distinctSet: AttributeSet = + if (joinType == LeftSemi) left.distinctSet else AttributeSet.empty + override protected def validConstraints: Set[Expression] = { joinType match { case Inner if condition.isDefined => @@ -312,6 +329,7 @@ case class Join( */ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def distinctSet: AttributeSet = child.distinctSet // We manually set statistics of BroadcastHint to smallest value to make sure // the plan wrapped by BroadcastHint will be considered to broadcast later. @@ -367,6 +385,7 @@ case class Sort( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows + override def distinctSet: AttributeSet = child.distinctSet } /** Factory for constructing new `Range` nodes. 
*/ @@ -422,6 +441,19 @@ case class Aggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows + override def distinctSet: AttributeSet = { + if (isForDistinct) { + AttributeSet(aggregateExpressions) + } else if (child.outputSet.nonEmpty && child.outputSet.subsetOf(outputSet)) { + child.distinctSet + } else { + AttributeSet.empty + } + } + + def isForDistinct: Boolean = + groupingExpressions.nonEmpty && groupingExpressions == aggregateExpressions + override def validConstraints: Set[Expression] = child.constraints.union(getAliasedConstraints(aggregateExpressions)) @@ -443,6 +475,8 @@ case class Window( override def output: Seq[Attribute] = child.output ++ windowExpressions.map(_.toAttribute) + override def distinctSet: AttributeSet = child.distinctSet + def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute)) } @@ -585,6 +619,7 @@ object Limit { case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def distinctSet: AttributeSet = child.distinctSet override def maxRows: Option[Long] = { limitExpr match { case IntegerLiteral(limit) => Some(limit) @@ -600,6 +635,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def distinctSet: AttributeSet = child.distinctSet override def maxRows: Option[Long] = { limitExpr match { case IntegerLiteral(limit) => Some(limit) @@ -615,6 +651,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode { + override def distinctSet: AttributeSet = child.distinctSet override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) } @@ 
-638,6 +675,7 @@ case class Sample( val isTableSample: java.lang.Boolean = false) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def distinctSet: AttributeSet = child.distinctSet override def statistics: Statistics = { val ratio = upperBound - lowerBound @@ -658,6 +696,7 @@ case class Sample( case class Distinct(child: LogicalPlan) extends UnaryNode { override def maxRows: Option[Long] = child.maxRows override def output: Seq[Attribute] = child.output + override def distinctSet: AttributeSet = child.outputSet } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index f8ae5d9be208..c9b8fb6f9d30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -28,8 +28,9 @@ class ReplaceOperatorSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Replace Operators", FixedPoint(100), + ReplaceIntersectWithSemiJoin, ReplaceDistinctWithAggregate, - ReplaceIntersectWithSemiJoin) :: Nil + EliminateDistinct) :: Nil } test("replace Intersect with Left-semi Join") { @@ -40,8 +41,79 @@ class ReplaceOperatorSuite extends PlanTest { val optimized = Optimize.execute(query.analyze) val correctAnswer = - Aggregate(table1.output, table1.output, - Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze + table1.join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd)).groupBy('a, 'b)('a, 'b).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Intersect with Left-semi Join whose left is Distinct") { + val table1 = LocalRelation('a.int, 'b.int) + val table2 = LocalRelation('c.int, 'd.int) + + val query = table1.distinct().intersect(table2) + val optimized = 
Optimize.execute(query.analyze) + + val correctAnswer = + table1.groupBy('a, 'b)('a, 'b).join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("continuous Intersect whose children containing Distinct") { + val table1 = LocalRelation('a.int, 'b.int) + val table2 = LocalRelation('c.int, 'd.int) + val table3 = LocalRelation('e.int, 'f.int) + + // DISTINCT (actually, it is AGGREGATE) is the direct child + val query1 = table1.distinct().intersect(table2).intersect(table3) + val correctAnswer1 = + table1.groupBy('a, 'b)('a, 'b) + .join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd)) + .join(table3, LeftSemi, Option('a <=> 'e && 'b <=> 'f)).analyze + comparePlans(Optimize.execute(query1.analyze), correctAnswer1) + } + + test("continuous Intersect whose children do not contain Distinct") { + val table1 = LocalRelation('a.int, 'b.int) + val table2 = LocalRelation('c.int, 'd.int) + val table3 = LocalRelation('e.int, 'f.int) + + // Just need one Distinct for continuous Intersect, even if no child has Distinct. 
+ val query1 = table1.intersect(table2).intersect(table3) + val correctAnswer1 = + table1 + .join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd)).groupBy('a, 'b)('a, 'b) + .join(table3, LeftSemi, Option('a <=> 'e && 'b <=> 'f)).analyze + comparePlans(Optimize.execute(query1.analyze), correctAnswer1) + } + + test("replace Intersect with Left-semi Join whose left is inferred to have distinct values") { + val table1 = LocalRelation('a.int) + val table2 = LocalRelation('c.int, 'd.int) + val table3 = LocalRelation('e.int, 'f.int) + + // DISTINCT is inferred from the child's child + val query2 = table1.distinct() + .where('a > 3).limit(5) + .select('a.attr, ('a + 1).as("b")).orderBy('a.asc, 'b.desc) + .intersect(table2).intersect(table3) + val correctAnswer2 = + table1.groupBy('a)('a).where('a > 3).limit(5) + .select('a.attr, ('a + 1).as("b")).orderBy('a.asc, 'b.desc) + .join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd)) + .join(table3, LeftSemi, Option('a <=> 'e && 'b <=> 'f)).analyze + comparePlans(Optimize.execute(query2.analyze), correctAnswer2) + } + + test("replace Intersect with Left-semi Join whose left is the Distinct") { + val table1 = LocalRelation('a.int, 'b.int) + val table2 = LocalRelation('c.int, 'd.int) + + val query = table1.groupBy('a, 'b)('a, 'b).intersect(table2) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + table1.groupBy('a, 'b)('a, 'b).join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd)).analyze comparePlans(optimized, correctAnswer) } @@ -49,10 +121,10 @@ class ReplaceOperatorSuite extends PlanTest { test("replace Distinct with Aggregate") { val input = LocalRelation('a.int, 'b.int) - val query = Distinct(input) + val query = input.distinct() val optimized = Optimize.execute(query.analyze) - val correctAnswer = Aggregate(input.output, input.output, input) + val correctAnswer = input.groupBy('a, 'b)('a, 'b).analyze comparePlans(optimized, correctAnswer) } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d03597ee5dca..edce22cef426 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -410,6 +410,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("id1", 1) :: Row("id", 1) :: Row("id1", 2) :: Nil) + + checkAnswer( + df.distinct().intersect(df), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) } test("intersect - nullability") {