
Commit a170d34

Davies Liu (davies) authored and committed
[SPARK-12395] [SQL] fix resulting columns of outer join
For the API DataFrame.join(right, usingColumns, joinType), if joinType is right_outer or full_outer, the resulting join columns could be wrong (they would be null). The order of the output columns has been changed to match MySQL and PostgreSQL [1]. This PR also fixes the nullability of the output for outer joins.

[1] http://www.postgresql.org/docs/9.2/static/queries-table-expressions.html

Author: Davies Liu <[email protected]>

Closes #10353 from davies/fix_join.
1 parent cd3d937 commit a170d34
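
For context, a minimal usage sketch of the API this commit fixes. It is written for a Spark 1.6-era spark-shell session, so it assumes an existing SparkContext named sc; the data and column names are hypothetical and not taken from the patch.

import org.apache.spark.sql.SQLContext

// Hypothetical example; assumes a SparkContext `sc` is already available (e.g. in spark-shell).
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val left = Seq((1, "a"), (2, "b")).toDF("k", "lv")
val right = Seq((2, "x"), (3, "y")).toDF("k", "rv")

// USING-style join on column "k". With this fix, the single output column "k" is
// coalesced from both sides for a full outer join, so it reads 1, 2, 3 instead of
// being null on rows that exist only in the right DataFrame.
left.join(right, Seq("k"), "outer").show()

As with PostgreSQL's FULL OUTER JOIN ... USING (k), the shared column behaves like COALESCE(left.k, right.k), which is the behavior the commit message cites.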

File tree

2 files changed: +36 -9 lines


sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 20 additions & 5 deletions
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution}
 import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
@@ -455,10 +455,8 @@ class DataFrame private[sql](
     // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
     // by creating a new instance for one of the branch.
     val joined = sqlContext.executePlan(
-      Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join]
+      Join(logicalPlan, right.logicalPlan, JoinType(joinType), None)).analyzed.asInstanceOf[Join]
 
-    // Project only one of the join columns.
-    val joinedCols = usingColumns.map(col => withPlan(joined.right).resolve(col))
     val condition = usingColumns.map { col =>
       catalyst.expressions.EqualTo(
         withPlan(joined.left).resolve(col),
@@ -467,9 +465,26 @@ class DataFrame private[sql](
       catalyst.expressions.And(cond, eqTo)
     }
 
+    // Project only one of the join columns.
+    val joinedCols = JoinType(joinType) match {
+      case Inner | LeftOuter | LeftSemi =>
+        usingColumns.map(col => withPlan(joined.left).resolve(col))
+      case RightOuter =>
+        usingColumns.map(col => withPlan(joined.right).resolve(col))
+      case FullOuter =>
+        usingColumns.map { col =>
+          val leftCol = withPlan(joined.left).resolve(col)
+          val rightCol = withPlan(joined.right).resolve(col)
+          Alias(Coalesce(Seq(leftCol, rightCol)), col)()
+        }
+    }
+    // The nullability of output of joined could be different than original column,
+    // so we can only compare them by exprId
+    val joinRefs = condition.map(_.references.toSeq.map(_.exprId)).getOrElse(Nil)
+    val resultCols = joinedCols ++ joined.output.filterNot(e => joinRefs.contains(e.exprId))
     withPlan {
       Project(
-        joined.output.filterNot(joinedCols.contains(_)),
+        resultCols,
         Join(
           joined.left,
           joined.right,

sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala

Lines changed: 16 additions & 4 deletions
@@ -43,16 +43,28 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
   }
 
   test("join - join using multiple columns and specifying join type") {
-    val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str")
-    val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str")
+    val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str")
+    val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str")
+
+    checkAnswer(
+      df.join(df2, Seq("int", "str"), "inner"),
+      Row(1, "1", 2, 3) :: Nil)
 
     checkAnswer(
       df.join(df2, Seq("int", "str"), "left"),
-      Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil)
+      Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Nil)
 
     checkAnswer(
       df.join(df2, Seq("int", "str"), "right"),
-      Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil)
+      Row(1, "1", 2, 3) :: Row(5, "5", null, 6) :: Nil)
+
+    checkAnswer(
+      df.join(df2, Seq("int", "str"), "outer"),
+      Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Row(5, "5", null, 6) :: Nil)
+
+    checkAnswer(
+      df.join(df2, Seq("int", "str"), "left_semi"),
+      Row(1, "1", 2) :: Nil)
   }
 
   test("join - join using self join") {
