Skip to content

Commit 5f686cc

Browse files
gatorsmilerxin
authored andcommitted
[SPARK-12656] [SQL] Implement Intersect with Left-semi Join
Our current Intersect physical operator simply delegates to RDD.intersect. We should remove the Intersect physical operator and simply transform a logical intersect into a semi-join with distinct. This way, we can take advantage of all the benefits of join implementations (e.g. managed memory, code generation, broadcast joins). After a search, I found one of the mainstream RDBMS did the same. In their query explain, Intersect is replaced by Left-semi Join. Left-semi Join could help outer-join elimination in Optimizer, as shown in the PR: #10566 Author: gatorsmile <[email protected]> Author: xiaoli <[email protected]> Author: Xiao Li <[email protected]> Closes #10630 from gatorsmile/IntersectBySemiJoin.
1 parent c5f745e commit 5f686cc

File tree

11 files changed

+211
-122
lines changed

11 files changed

+211
-122
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 62 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,63 @@ class Analyzer(
344344
}
345345
}
346346

347+
/**
348+
* Generate a new logical plan for the right child with different expression IDs
349+
* for all conflicting attributes.
350+
*/
351+
private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
352+
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
353+
logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
354+
s"between $left and $right")
355+
356+
right.collect {
357+
// Handle base relations that might appear more than once.
358+
case oldVersion: MultiInstanceRelation
359+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
360+
val newVersion = oldVersion.newInstance()
361+
(oldVersion, newVersion)
362+
363+
// Handle projects that create conflicting aliases.
364+
case oldVersion @ Project(projectList, _)
365+
if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
366+
(oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
367+
368+
case oldVersion @ Aggregate(_, aggregateExpressions, _)
369+
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
370+
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
371+
372+
case oldVersion: Generate
373+
if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
374+
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
375+
(oldVersion, oldVersion.copy(generatorOutput = newOutput))
376+
377+
case oldVersion @ Window(_, windowExpressions, _, _, child)
378+
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
379+
.nonEmpty =>
380+
(oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
381+
}
382+
// Only handle first case, others will be fixed on the next pass.
383+
.headOption match {
384+
case None =>
385+
/*
386+
* No result implies that there is a logical plan node that produces new references
387+
* that this rule cannot handle. When that is the case, there must be another rule
388+
* that resolves these conflicts. Otherwise, the analysis will fail.
389+
*/
390+
right
391+
case Some((oldRelation, newRelation)) =>
392+
val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
393+
val newRight = right transformUp {
394+
case r if r == oldRelation => newRelation
395+
} transformUp {
396+
case other => other transformExpressions {
397+
case a: Attribute => attributeRewrites.get(a).getOrElse(a)
398+
}
399+
}
400+
newRight
401+
}
402+
}
403+
347404
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
348405
case p: LogicalPlan if !p.childrenResolved => p
349406

@@ -388,57 +445,11 @@ class Analyzer(
388445
.map(_.asInstanceOf[NamedExpression])
389446
a.copy(aggregateExpressions = expanded)
390447

391-
// Special handling for cases when self-join introduce duplicate expression ids.
392-
case j @ Join(left, right, _, _) if !j.selfJoinResolved =>
393-
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
394-
logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j")
395-
396-
right.collect {
397-
// Handle base relations that might appear more than once.
398-
case oldVersion: MultiInstanceRelation
399-
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
400-
val newVersion = oldVersion.newInstance()
401-
(oldVersion, newVersion)
402-
403-
// Handle projects that create conflicting aliases.
404-
case oldVersion @ Project(projectList, _)
405-
if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
406-
(oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
407-
408-
case oldVersion @ Aggregate(_, aggregateExpressions, _)
409-
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
410-
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
411-
412-
case oldVersion: Generate
413-
if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
414-
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
415-
(oldVersion, oldVersion.copy(generatorOutput = newOutput))
416-
417-
case oldVersion @ Window(_, windowExpressions, _, _, child)
418-
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
419-
.nonEmpty =>
420-
(oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
421-
}
422-
// Only handle first case, others will be fixed on the next pass.
423-
.headOption match {
424-
case None =>
425-
/*
426-
* No result implies that there is a logical plan node that produces new references
427-
* that this rule cannot handle. When that is the case, there must be another rule
428-
* that resolves these conflicts. Otherwise, the analysis will fail.
429-
*/
430-
j
431-
case Some((oldRelation, newRelation)) =>
432-
val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
433-
val newRight = right transformUp {
434-
case r if r == oldRelation => newRelation
435-
} transformUp {
436-
case other => other transformExpressions {
437-
case a: Attribute => attributeRewrites.get(a).getOrElse(a)
438-
}
439-
}
440-
j.copy(right = newRight)
441-
}
448+
// To resolve duplicate expression IDs for Join and Intersect
449+
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
450+
j.copy(right = dedupRight(left, right))
451+
case i @ Intersect(left, right) if !i.duplicateResolved =>
452+
i.copy(right = dedupRight(left, right))
442453

443454
// When resolve `SortOrder`s in Sort based on child, don't report errors as
444455
// we still have chance to resolve it based on grandchild

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,16 +214,24 @@ trait CheckAnalysis {
214214
s"""Only a single table generating function is allowed in a SELECT clause, found:
215215
| ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
216216

217-
// Special handling for cases when self-join introduce duplicate expression ids.
218-
case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
219-
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
217+
case j: Join if !j.duplicateResolved =>
218+
val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
220219
failAnalysis(
221220
s"""
222221
|Failure when resolving conflicting references in Join:
223222
|$plan
224223
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
225224
|""".stripMargin)
226225

226+
case i: Intersect if !i.duplicateResolved =>
227+
val conflictingAttributes = i.left.outputSet.intersect(i.right.outputSet)
228+
failAnalysis(
229+
s"""
230+
|Failure when resolving conflicting references in Intersect:
231+
|$plan
232+
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
233+
|""".stripMargin)
234+
227235
case o if !o.resolved =>
228236
failAnalysis(
229237
s"unresolved operator ${operator.simpleString}")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
5252
// since the other rules might make two separate Unions operators adjacent.
5353
Batch("Union", Once,
5454
CombineUnions) ::
55+
Batch("Replace Operators", FixedPoint(100),
56+
ReplaceIntersectWithSemiJoin,
57+
ReplaceDistinctWithAggregate) ::
5558
Batch("Aggregate", FixedPoint(100),
56-
ReplaceDistinctWithAggregate,
5759
RemoveLiteralFromGroupExpressions) ::
5860
Batch("Operator Optimizations", FixedPoint(100),
5961
// Operator push down
@@ -124,18 +126,13 @@ object EliminateSerialization extends Rule[LogicalPlan] {
124126
}
125127

126128
/**
127-
* Pushes certain operations to both sides of a Union, Intersect or Except operator.
129+
* Pushes certain operations to both sides of a Union or Except operator.
128130
* Operations that are safe to pushdown are listed as follows.
129131
* Union:
130132
* Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is
131133
* safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT,
132134
* we will not be able to pushdown Projections.
133135
*
134-
* Intersect:
135-
* It is not safe to pushdown Projections through it because we need to get the
136-
* intersect of rows by comparing the entire rows. It is fine to pushdown Filters
137-
* with deterministic condition.
138-
*
139136
* Except:
140137
* It is not safe to pushdown Projections through it because we need to get the
141138
* intersect of rows by comparing the entire rows. It is fine to pushdown Filters
@@ -153,7 +150,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
153150

154151
/**
155152
* Rewrites an expression so that it can be pushed to the right side of a
156-
* Union, Intersect or Except operator. This method relies on the fact that the output attributes
153+
* Union or Except operator. This method relies on the fact that the output attributes
157154
* of a union/intersect/except are always equal to the left child's output.
158155
*/
159156
private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = {
@@ -210,17 +207,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
210207
}
211208
Filter(nondeterministic, Union(newFirstChild +: newOtherChildren))
212209

213-
// Push down filter through INTERSECT
214-
case Filter(condition, Intersect(left, right)) =>
215-
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
216-
val rewrites = buildRewrites(left, right)
217-
Filter(nondeterministic,
218-
Intersect(
219-
Filter(deterministic, left),
220-
Filter(pushToRight(deterministic, rewrites), right)
221-
)
222-
)
223-
224210
// Push down filter through EXCEPT
225211
case Filter(condition, Except(left, right)) =>
226212
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
@@ -1054,6 +1040,27 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
10541040
}
10551041
}
10561042

1043+
/**
1044+
* Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator.
1045+
* {{{
1046+
* SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2
1047+
* ==> SELECT DISTINCT a1, a2 FROM Tab1 LEFT SEMI JOIN Tab2 ON a1<=>b1 AND a2<=>b2
1048+
* }}}
1049+
*
1050+
* Note:
1051+
* 1. This rule is only applicable to INTERSECT DISTINCT. Do not use it for INTERSECT ALL.
1052+
* 2. This rule has to be done after de-duplicating the attributes; otherwise, the generated
1053+
* join conditions will be incorrect.
1054+
*/
1055+
object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
1056+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
1057+
case Intersect(left, right) =>
1058+
assert(left.output.size == right.output.size)
1059+
val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
1060+
Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And)))
1061+
}
1062+
}
1063+
10571064
/**
10581065
* Removes literals from group expressions in [[Aggregate]], as they have no effect to the result
10591066
* but only makes the grouping key bigger.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
1919

2020
import scala.collection.mutable.ArrayBuffer
2121

22+
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
2223
import org.apache.spark.sql.catalyst.expressions._
2324
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
2425
import org.apache.spark.sql.catalyst.plans._
@@ -90,28 +91,38 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
9091
override def output: Seq[Attribute] = child.output
9192
}
9293

93-
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
94-
final override lazy val resolved: Boolean =
95-
childrenResolved &&
96-
left.output.length == right.output.length &&
97-
left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
98-
}
94+
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode
9995

10096
private[sql] object SetOperation {
10197
def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
10298
}
10399

104100
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
105101

102+
def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
103+
106104
override def output: Seq[Attribute] =
107105
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
108106
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
109107
}
108+
109+
// Intersect are only resolved if they don't introduce ambiguous expression ids,
110+
// since the Optimizer will convert Intersect to Join.
111+
override lazy val resolved: Boolean =
112+
childrenResolved &&
113+
left.output.length == right.output.length &&
114+
left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } &&
115+
duplicateResolved
110116
}
111117

112118
case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
113119
/** We don't use right.output because those rows get excluded from the set. */
114120
override def output: Seq[Attribute] = left.output
121+
122+
override lazy val resolved: Boolean =
123+
childrenResolved &&
124+
left.output.length == right.output.length &&
125+
left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
115126
}
116127

117128
/** Factory for constructing new `Union` nodes. */
@@ -169,13 +180,13 @@ case class Join(
169180
}
170181
}
171182

172-
def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
183+
def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
173184

174185
// Joins are only resolved if they don't introduce ambiguous expression ids.
175186
override lazy val resolved: Boolean = {
176187
childrenResolved &&
177188
expressions.forall(_.resolved) &&
178-
selfJoinResolved &&
189+
duplicateResolved &&
179190
condition.forall(_.dataType == BooleanType)
180191
}
181192
}
@@ -249,7 +260,7 @@ case class Range(
249260
end: Long,
250261
step: Long,
251262
numSlices: Int,
252-
output: Seq[Attribute]) extends LeafNode {
263+
output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation {
253264
require(step != 0, "step cannot be 0")
254265
val numElements: BigInt = {
255266
val safeStart = BigInt(start)
@@ -262,6 +273,9 @@ case class Range(
262273
}
263274
}
264275

276+
override def newInstance(): Range =
277+
Range(start, end, step, numSlices, output.map(_.newInstance()))
278+
265279
override def statistics: Statistics = {
266280
val sizeInBytes = LongType.defaultSize * numElements
267281
Statistics( sizeInBytes = sizeInBytes )

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ class AnalysisSuite extends AnalysisTest {
154154
checkAnalysis(plan, expected)
155155
}
156156

157+
test("self intersect should resolve duplicate expression IDs") {
158+
val plan = testRelation.intersect(testRelation)
159+
assertAnalysisSuccess(plan)
160+
}
161+
157162
test("SPARK-8654: invalid CAST in NULL IN(...) expression") {
158163
val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil,
159164
LocalRelation()

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,9 @@ class AggregateOptimizeSuite extends PlanTest {
2828

2929
object Optimize extends RuleExecutor[LogicalPlan] {
3030
val batches = Batch("Aggregate", FixedPoint(100),
31-
ReplaceDistinctWithAggregate,
3231
RemoveLiteralFromGroupExpressions) :: Nil
3332
}
3433

35-
test("replace distinct with aggregate") {
36-
val input = LocalRelation('a.int, 'b.int)
37-
38-
val query = Distinct(input)
39-
val optimized = Optimize.execute(query.analyze)
40-
41-
val correctAnswer = Aggregate(input.output, input.output, input)
42-
43-
comparePlans(optimized, correctAnswer)
44-
}
45-
4634
test("remove literals in grouping expression") {
4735
val input = LocalRelation('a.int, 'b.int)
4836

0 commit comments

Comments
 (0)