Skip to content

Commit dee801a

Browse files
committed
[SPARK-12828][SQL] Natural join follow-up
This is a small addendum to #10762 to make the code more robust against future changes. Author: Reynold Xin <[email protected]> Closes #11070 from rxin/SPARK-12828-natural-join.
1 parent d390871 commit dee801a

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,35 +1239,38 @@ class Analyzer(
12391239
*/
12401240
object ResolveNaturalJoin extends Rule[LogicalPlan] {
12411241
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
1242-
// Should not skip unresolved nodes because natural join is always unresolved.
12431242
case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
1244-
// find common column names from both sides, should be treated like usingColumns
1243+
// find common column names from both sides
12451244
val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
12461245
val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
12471246
val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
12481247
val joinPairs = leftKeys.zip(rightKeys)
1248+
12491249
// Add joinPairs to joinConditions
12501250
val newCondition = (condition ++ joinPairs.map {
12511251
case (l, r) => EqualTo(l, r)
1252-
}).reduceLeftOption(And)
1252+
}).reduceOption(And)
1253+
12531254
// columns not in joinPairs
12541255
val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
12551256
val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
1256-
// we should only keep unique columns(depends on joinType) for joinCols
1257+
1258+
// the output list looks like: join keys, columns from left, columns from right
12571259
val projectList = joinType match {
12581260
case LeftOuter =>
12591261
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
12601262
case RightOuter =>
12611263
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
12621264
case FullOuter =>
12631265
// in full outer join, joinCols should be non-null if there is.
1264-
val joinedCols = joinPairs.map {
1265-
case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)()
1266-
}
1267-
joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++
1266+
val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
1267+
joinedCols ++
1268+
lUniqueOutput.map(_.withNullability(true)) ++
12681269
rUniqueOutput.map(_.withNullability(true))
1269-
case _ =>
1270+
case Inner =>
12701271
rightKeys ++ lUniqueOutput ++ rUniqueOutput
1272+
case _ =>
1273+
sys.error("Unsupported natural join type " + joinType)
12711274
}
12721275
// use Project to trim unnecessary fields
12731276
Project(projectList, Join(left, right, joinType, newCondition))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,7 @@ case object LeftSemi extends JoinType {
6262
}
6363

6464
case class NaturalJoin(tpe: JoinType) extends JoinType {
65+
require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe),
66+
"Unsupported natural join type " + tpe)
6567
override def sql: String = "NATURAL " + tpe.sql
6668
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
3030
lazy val aNotNull = a.notNull
3131
lazy val bNotNull = b.notNull
3232
lazy val cNotNull = c.notNull
33-
lazy val r1 = LocalRelation(a, b)
34-
lazy val r2 = LocalRelation(a, c)
33+
lazy val r1 = LocalRelation(b, a)
34+
lazy val r2 = LocalRelation(c, a)
3535
lazy val r3 = LocalRelation(aNotNull, bNotNull)
36-
lazy val r4 = LocalRelation(bNotNull, cNotNull)
36+
lazy val r4 = LocalRelation(cNotNull, bNotNull)
3737

3838
test("natural inner join") {
3939
val plan = r1.join(r2, NaturalJoin(Inner), None)

0 commit comments

Comments
 (0)