@@ -1239,35 +1239,38 @@ class Analyzer(
12391239 */
12401240 object ResolveNaturalJoin extends Rule [LogicalPlan ] {
12411241 override def apply (plan : LogicalPlan ): LogicalPlan = plan resolveOperators {
1242- // Should not skip unresolved nodes because natural join is always unresolved.
12431242 case j @ Join (left, right, NaturalJoin (joinType), condition) if j.resolvedExceptNatural =>
1244- // find common column names from both sides, should be treated like usingColumns
1243+ // find common column names from both sides
12451244 val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
12461245 val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
12471246 val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
12481247 val joinPairs = leftKeys.zip(rightKeys)
1248+
12491249 // Add joinPairs to joinConditions
12501250 val newCondition = (condition ++ joinPairs.map {
12511251 case (l, r) => EqualTo (l, r)
1252- }).reduceLeftOption(And )
1252+ }).reduceOption(And )
1253+
12531254 // columns not in joinPairs
12541255 val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
12551256 val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
1256- // we should only keep unique columns(depends on joinType) for joinCols
1257+
1258+ // the output list looks like: join keys, columns from left, columns from right
12571259 val projectList = joinType match {
12581260 case LeftOuter =>
12591261 leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true ))
12601262 case RightOuter =>
12611263 rightKeys ++ lUniqueOutput.map(_.withNullability(true )) ++ rUniqueOutput
12621264 case FullOuter =>
12631265 // in full outer join, joinCols should be non-null if there is.
1264- val joinedCols = joinPairs.map {
1265- case (l, r) => Alias (Coalesce (Seq (l, r)), l.name)()
1266- }
1267- joinedCols ++ lUniqueOutput.map(_.withNullability(true )) ++
1266+ val joinedCols = joinPairs.map { case (l, r) => Alias (Coalesce (Seq (l, r)), l.name)() }
1267+ joinedCols ++
1268+ lUniqueOutput.map(_.withNullability(true )) ++
12681269 rUniqueOutput.map(_.withNullability(true ))
1269- case _ =>
1270+ case Inner =>
12701271 rightKeys ++ lUniqueOutput ++ rUniqueOutput
1272+ case _ =>
1273+ sys.error(" Unsupported natural join type " + joinType)
12711274 }
12721275 // use Project to trim unnecessary fields
12731276 Project (projectList, Join (left, right, joinType, newCondition))
0 commit comments