Skip to content

Commit cb8af0e

Browse files
committed
address comments
1 parent 88a52c2 commit cb8af0e

File tree

4 files changed

+43
-7
lines changed

4 files changed

+43
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,30 +1169,36 @@ class Analyzer(
11691169
object ResolveNaturalJoin extends Rule[LogicalPlan] {
11701170
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
11711171
// Should not skip unresolved nodes because natural join is always unresolved.
1172-
case j @ Join(left, right, NaturalJoin(joinType), condition) if j.partlyResolved =>
1172+
case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
1173+
// find common column names from both sides, should be treated like usingColumns
11731174
val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
11741175
val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
11751176
val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
11761177
val joinPairs = leftKeys.zip(rightKeys)
1178+
// Add joinPairs to joinConditions
11771179
val newCondition = (condition ++ joinPairs.map {
11781180
case (l, r) => EqualTo(l, r)
11791181
}).reduceLeftOption(And)
1182+
// columns not in joinPairs
11801183
val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
11811184
val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
1185+
// we should only keep unique columns(depends on joinType) for joinCols
11821186
val projectList = joinType match {
11831187
case LeftOuter =>
11841188
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
11851189
case RightOuter =>
11861190
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
11871191
case FullOuter =>
1192+
// in full outer join, joinCols should be non-null if there is.
11881193
val joinedCols = joinPairs.map {
11891194
case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)()
11901195
}
11911196
joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++
11921197
rUniqueOutput.map(_.withNullability(true))
11931198
case _ =>
1194-
leftKeys ++ lUniqueOutput ++ rUniqueOutput
1199+
rightKeys ++ lUniqueOutput ++ rUniqueOutput
11951200
}
1201+
// use Project to trim unnecessary fields
11961202
Project(projectList, Join(left, right, joinType, newCondition))
11971203
}
11981204
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,6 @@ trait CheckAnalysis {
105105
s"filter expression '${f.condition.prettyString}' " +
106106
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
107107

108-
case j @ Join(_, _, NaturalJoin(_), _) =>
109-
failAnalysis(s"natural join not resolved.")
110-
111108
case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
112109
failAnalysis(
113110
s"join condition '${condition.prettyString}' " +

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,16 +171,20 @@ case class Join(
171171

172172
def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
173173

174-
lazy val partlyResolved: Boolean = {
174+
// if not a natural join, it is resolved. if it is a natural join, we still need to
175+
// eliminate natural before we mark it resolved, but the node should be ready for
176+
// resolution only if everything else is resolved here.
177+
lazy val resolvedExceptNatural: Boolean = {
175178
childrenResolved &&
176179
expressions.forall(_.resolved) &&
177180
selfJoinResolved &&
178181
condition.forall(_.dataType == BooleanType)
179182
}
183+
180184
// Joins are only resolved if they don't introduce ambiguous expression ids.
181185
override lazy val resolved: Boolean = joinType match {
182186
case NaturalJoin(_) => false
183-
case _ => partlyResolved
187+
case _ => resolvedExceptNatural
184188
}
185189
}
186190

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,17 @@ class AnalysisSuite extends AnalysisTest {
256256
val a = testRelation2.output(0)
257257
val b = testRelation2.output(1)
258258
val c = testRelation2.output(2)
259+
val testRelation0 = LocalRelation(
260+
AttributeReference("a", StringType, nullable = false)(),
261+
AttributeReference("b", StringType, nullable = false)(),
262+
AttributeReference("c", StringType, nullable = false)())
263+
val tt1 = testRelation0.select('a, 'b)
264+
val tt2 = testRelation0.select('a, 'c)
265+
val aa = testRelation0.output(0)
266+
val bb = testRelation0.output(1)
267+
val cc = testRelation0.output(2)
268+
val truebb = testRelation0.output(1).withNullability(true)
269+
val truecc = testRelation0.output(2).withNullability(true)
259270

260271
val plan1 = t1.join(t2, NaturalJoin(Inner), None)
261272
val expected1 = testRelation2.select(a, b).join(
@@ -273,5 +284,23 @@ class AnalysisSuite extends AnalysisTest {
273284
val expected4 = testRelation2.select(a, b).join(testRelation2.select(
274285
a, c), FullOuter, Some(EqualTo(a, a))).select(Alias(Coalesce(Seq(a, a)), "a")(), b, c)
275286
checkAnalysis(plan4, expected4)
287+
288+
val plan5 = tt1.join(tt2, NaturalJoin(Inner), None)
289+
val expected5 = testRelation0.select(aa, bb).join(
290+
testRelation0.select(aa, cc), Inner, Some(EqualTo(aa, aa))).select(aa, bb, cc)
291+
checkAnalysis(plan5, expected5)
292+
val plan6 = tt1.join(tt2, NaturalJoin(LeftOuter), None)
293+
val expected6 = testRelation0.select(aa, bb).join(
294+
testRelation0.select(aa, cc), LeftOuter, Some(EqualTo(aa, aa))).select(aa, bb, truecc)
295+
checkAnalysis(plan6, expected6)
296+
val plan7 = tt1.join(tt2, NaturalJoin(RightOuter), None)
297+
val expected7 = testRelation0.select(aa, bb).join(
298+
testRelation0.select(aa, cc), RightOuter, Some(EqualTo(aa, aa))).select(aa, truebb, cc)
299+
checkAnalysis(plan7, expected7)
300+
val plan8 = tt1.join(tt2, NaturalJoin(FullOuter), None)
301+
val expected8 = testRelation0.select(aa, bb).join(
302+
testRelation0.select(aa, cc), FullOuter, Some(EqualTo(aa, aa))).select(
303+
Alias(Coalesce(Seq(aa, aa)), "a")(), truebb, truecc)
304+
checkAnalysis(plan8, expected8)
276305
}
277306
}

0 commit comments

Comments
 (0)