From e0fc72a17b5536d02493402e7533d625bf82fb07 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 14 Jan 2016 00:31:00 -0800 Subject: [PATCH 01/17] add natural join support --- .../sql/catalyst/analysis/Analyzer.scala | 20 +++++++ .../plans/logical/basicOperators.scala | 56 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 20 +++++++ 3 files changed, 96 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 33d76eeb2128..670492f7bf20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -80,6 +80,7 @@ class Analyzer( ResolveAliases :: ResolveWindowOrder :: ResolveWindowFrame :: + ResolveNaturalJoin :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -1159,6 +1160,25 @@ class Analyzer( } } } + + /** + * Removes natural joins. + */ + object ResolveNaturalJoin extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p // Skip unresolved nodes. + + case j @ NaturalJoin(left, right, joinType, condition) => + val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) + val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) + val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) + val joinPairs = leftKeys.zip(rightKeys) + val newCondition = (condition ++ joinPairs.map { + case (l, r) => EqualTo(l, r) + }).reduceLeftOption(And) + Project(j.outerProjectList, Join(left, right, joinType, newCondition)) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e9c970cd0808..41be2b433761 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -180,6 +180,62 @@ case class Join( } } +case class NaturalJoin( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { + + override def output: Seq[Attribute] = { + val lUniqueOutput = left.output.filterNot(att => commonNames.contains(att.name)) + val rUniqueOutput = right.output.filterNot(att => commonNames.contains(att.name)) + joinType match { + case LeftOuter => + commonOutputFromLeft ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) + case RightOuter => + val commonOutputFromRight = + commonNames.map(cn => right.output.find(att => att.name == cn).get) + commonOutputFromRight ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput + case FullOuter => + // here use left as a place holder + commonOutputFromLeft ++ (lUniqueOutput ++ rUniqueOutput).map(_.withNullability(true)) + case _ => + commonOutputFromLeft ++ lUniqueOutput ++ rUniqueOutput + } + } + + @transient private val leftNames = left.output.map(_.name) + @transient private val rightNames = right.output.map(_.name) + @transient private val commonNames = leftNames.intersect(rightNames) + @transient + private val commonOutputFromLeft = left.output.filter(att => commonNames.contains(att.name)) + + def outerProjectList: Seq[NamedExpression] = { + if (joinType == FullOuter) { + val commonOutputFromRight = + commonNames.map(cn => right.output.find(att => att.name == cn).get) + val commonPairs = commonOutputFromLeft.zip(commonOutputFromRight) + val commonOutputExp = commonPairs.map { + case (l: Attribute, r: Attribute) => Alias(CaseWhen(Seq((IsNull(l), r)), l), l.name)() + } + commonOutputExp ++ output.takeRight(output.size - commonOutputExp.size) + } else { + output + } + } + + + def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + + // Joins are only resolved if they don't introduce ambiguous expression ids. + override lazy val resolved: Boolean = { + childrenResolved && + expressions.forall(_.resolved) && + selfJoinResolved && + condition.forall(_.dataType == BooleanType) + } +} + /** * A hint for the optimizer that we should broadcast the `child` if used in a join operator. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 47308966e92c..9b40688b0234 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2056,4 +2056,24 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } } + + test("natural join") { + Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1").registerTempTable("nt1") + Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2").registerTempTable("nt2") + checkAnswer( + sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""), + Row("one", 1, 1) :: Row("one", 1, 5) :: Nil) + + checkAnswer( + sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil) + + checkAnswer( + sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil) + + checkAnswer( + sql("SELECT count(*) FROM nt1 natural full outer join nt2"), + Row(4) :: Nil) + } } From b5611d5e178af04a29d01b4a08043d9b29ed6a50 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 15 Jan 2016 00:22:45 -0800 Subject: [PATCH 02/17] add in new parser, and use the old Join node, and add test for Analyzer --- .../sql/catalyst/parser/FromClauseParser.g | 23 +++++--- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 2 + .../sql/catalyst/parser/SparkSqlParser.g | 4 ++ .../spark/sql/catalyst/CatalystQl.scala | 4 ++ .../sql/catalyst/analysis/Analyzer.scala | 7 ++- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 ++ .../spark/sql/catalyst/plans/joinTypes.scala | 4 ++ .../plans/logical/basicOperators.scala | 59 ++++++------------- .../sql/catalyst/analysis/AnalysisSuite.scala | 17 ++++++ .../sql/hive/execution/SQLQuerySuite.scala | 20 +++++++ 10 files changed, 90 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g index 6d76afcd4ac0..e83f8a7cd1b5 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g @@ -117,15 +117,20 @@ joinToken @init { gParent.pushMsg("join type specifier", state); } @after { gParent.popMsg(state); } : - KW_JOIN -> TOK_JOIN - | KW_INNER KW_JOIN -> TOK_JOIN - | COMMA -> TOK_JOIN - | KW_CROSS KW_JOIN -> TOK_CROSSJOIN - | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN - | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN - | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN - | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN - | KW_ANTI KW_JOIN -> TOK_ANTIJOIN + KW_JOIN -> TOK_JOIN + | KW_INNER KW_JOIN -> TOK_JOIN + | KW_NATURAL KW_JOIN -> TOK_NATURALJOIN + | KW_NATURAL KW_INNER KW_JOIN -> TOK_NATURALJOIN + | COMMA -> TOK_JOIN + | KW_CROSS KW_JOIN -> TOK_CROSSJOIN + | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN + | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN + | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN + | KW_NATURAL KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_NATURALLEFTOUTERJOIN + | KW_NATURAL KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_NATURALRIGHTOUTERJOIN + | KW_NATURAL KW_FULL (KW_OUTER)? KW_JOIN -> TOK_NATURALFULLOUTERJOIN + | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN + | KW_ANTI KW_JOIN -> TOK_ANTIJOIN ; lateralView diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index e4ffc634e8bf..e0178dcfa574 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -328,6 +328,8 @@ KW_WEEK: 'WEEK'|'WEEKS'; KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS'; KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS'; +KW_NATURAL: 'NATURAL'; + // Operators // NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index c146ca591488..6f2afc79f48b 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -96,6 +96,10 @@ TOK_RIGHTOUTERJOIN; TOK_FULLOUTERJOIN; TOK_UNIQUEJOIN; TOK_CROSSJOIN; +TOK_NATURALJOIN; +TOK_NATURALLEFTOUTERJOIN; +TOK_NATURALRIGHTOUTERJOIN; +TOK_NATURALFULLOUTERJOIN; TOK_LOAD; TOK_EXPORT; TOK_IMPORT; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index f531d59a75cf..4ee04bb23abe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -495,6 +495,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case "TOK_LEFTSEMIJOIN" => LeftSemi case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node) case "TOK_ANTIJOIN" => noParseRule("Anti Join", node) + case "TOK_NATURALJOIN" => NaturalJoin(Inner) + case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter) + case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter) + case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter) } Join(nodeToRelation(relation1), nodeToRelation(relation2), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 670492f7bf20..0d262a838bd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef @@ -1160,7 +1161,7 @@ class Analyzer( } } } - + /** * Removes natural joins. */ @@ -1168,7 +1169,7 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. - case j @ NaturalJoin(left, right, joinType, condition) => + case j @ Join(left, right, NaturalJoin(joinType), condition) => val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) @@ -1176,7 +1177,7 @@ class Analyzer( val newCondition = (condition ++ joinPairs.map { case (l, r) => EqualTo(l, r) }).reduceLeftOption(And) - Project(j.outerProjectList, Join(left, right, joinType, newCondition)) + Project(j.outerProjectList(joinType), Join(left, right, joinType, newCondition)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index f2e78d97442e..a33ad5f24dd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.NaturalJoin import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -104,6 +105,9 @@ trait CheckAnalysis { s"filter expression '${f.condition.prettyString}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + case j @ Join(_, _, NaturalJoin(_), _) => + failAnalysis(s"natural join not resolved.") + case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( s"join condition '${condition.prettyString}' " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index a5f6764aef7c..b10f1e63a73e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -60,3 +60,7 @@ case object FullOuter extends JoinType { case object LeftSemi extends JoinType { override def sql: String = "LEFT SEMI" } + +case class NaturalJoin(tpe: JoinType) extends JoinType { + override def sql: String = "NATURAL " + tpe.sql +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 41be2b433761..33399ef0eefc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -164,32 +164,21 @@ case class Join( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case NaturalJoin(jt) => + outerProjectList(jt).map(_.toAttribute) case _ => left.output ++ right.output } } - def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - - // Joins are only resolved if they don't introduce ambiguous expression ids. - override lazy val resolved: Boolean = { - childrenResolved && - expressions.forall(_.resolved) && - selfJoinResolved && - condition.forall(_.dataType == BooleanType) - } -} - -case class NaturalJoin( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { - - override def output: Seq[Attribute] = { + def outerProjectList(jt: JoinType): Seq[NamedExpression] = { + val leftNames = left.output.map(_.name) + val rightNames = right.output.map(_.name) + val commonNames = leftNames.intersect(rightNames) + val commonOutputFromLeft = left.output.filter(att => commonNames.contains(att.name)) val lUniqueOutput = left.output.filterNot(att => commonNames.contains(att.name)) val rUniqueOutput = right.output.filterNot(att => commonNames.contains(att.name)) - joinType match { + jt match { case LeftOuter => commonOutputFromLeft ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) case RightOuter => @@ -197,34 +186,20 @@ case class NaturalJoin( commonNames.map(cn => right.output.find(att => att.name == cn).get) commonOutputFromRight ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput case FullOuter => - // here use left as a place holder - commonOutputFromLeft ++ (lUniqueOutput ++ rUniqueOutput).map(_.withNullability(true)) + val commonOutputFromRight = + commonNames.map(cn => right.output.find(att => att.name == cn).get) + val commonPairs = commonOutputFromLeft.zip(commonOutputFromRight) + val commonOutputExp = commonPairs.map { + case (l: Attribute, r: Attribute) => + Alias(CaseWhen(Seq((IsNull(l), r)), l), l.name)() + } + commonOutputExp ++ + lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput.map(_.withNullability(true)) case _ => commonOutputFromLeft ++ lUniqueOutput ++ rUniqueOutput } } - @transient private val leftNames = left.output.map(_.name) - @transient private val rightNames = right.output.map(_.name) - @transient private val commonNames = leftNames.intersect(rightNames) - @transient - private val commonOutputFromLeft = left.output.filter(att => commonNames.contains(att.name)) - - def outerProjectList: Seq[NamedExpression] = { - if (joinType == FullOuter) { - val commonOutputFromRight = - commonNames.map(cn => right.output.find(att => att.name == cn).get) - val commonPairs = commonOutputFromLeft.zip(commonOutputFromRight) - val commonOutputExp = commonPairs.map { - case (l: Attribute, r: Attribute) => Alias(CaseWhen(Seq((IsNull(l), r)), l), l.name)() - } - commonOutputExp ++ output.takeRight(output.size - commonOutputExp.size) - } else { - output - } - } - - def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ab680282208c..a22dcc775ec8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -248,4 +249,20 @@ class AnalysisSuite extends AnalysisTest { val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val")) assertAnalysisSuccess(plan) } + + test("check resolve of natural join") { + val a = testRelation2.output(0) + val c = testRelation2.output(2) + + val t1 = testRelation2.select('a, 'b) + val t2 = testRelation2.select('a, 'c) + val plan1 = t1.join(t2, NaturalJoin(Inner), None) + assertAnalysisSuccess(plan1) + val plan2 = t1.join(t2, NaturalJoin(LeftOuter), None) + assertAnalysisSuccess(plan2) + val plan3 = t1.join(t2, NaturalJoin(RightOuter), None) + assertAnalysisSuccess(plan3) + val plan4 = t1.join(t2, NaturalJoin(FullOuter), None) + assertAnalysisSuccess(plan4) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 683008960aa2..de5c677eb164 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1463,4 +1463,24 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", BigDecimal("3.14"), "hello")) } + + test("SPARK-12828: natural join support") { + Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1").registerTempTable("nt1") + Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2").registerTempTable("nt2") + checkAnswer( + sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""), + Row("one", 1, 1) :: Row("one", 1, 5) :: Nil) + + checkAnswer( + sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil) + + checkAnswer( + sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil) + + checkAnswer( + sql("SELECT count(*) FROM nt1 natural full outer join nt2"), + Row(4) :: Nil) + } } From 20413824e140ca769da4b988765ab677a56656f4 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 19 Jan 2016 23:38:21 -0800 Subject: [PATCH 03/17] fix compile --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 +++- .../spark/sql/catalyst/plans/logical/basicOperators.scala | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0d262a838bd6..87e0bacf31a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.NaturalJoin import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 44455b482074..f2323aaf4230 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ @@ -919,6 +919,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (rightFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) case FullOuter => f // DO Nothing for Full Outer Join + case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") } // push down the join filter into sub query scanning if applicable @@ -953,6 +954,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { Join(newLeft, newRight, LeftOuter, newJoinCond) case FullOuter => f + case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 33399ef0eefc..98e73af4972c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -191,7 +191,7 @@ case class Join( val commonPairs = commonOutputFromLeft.zip(commonOutputFromRight) val commonOutputExp = commonPairs.map { case (l: Attribute, r: Attribute) => - Alias(CaseWhen(Seq((IsNull(l), r)), l), l.name)() + Alias(Coalesce(Seq(l, r)), l.name)() } commonOutputExp ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput.map(_.withNullability(true)) From afb60a59fba17b07f5c744b4e961a2531eceef27 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 20 Jan 2016 01:03:15 -0800 Subject: [PATCH 04/17] fix df --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 518f9dcf94a7..dd29c18311ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -474,6 +474,7 @@ class DataFrame private[sql]( val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true) Alias(Coalesce(Seq(leftCol, rightCol)), col)() } + case NaturalJoin(_) => sys.error("NaturalJoin with using clause is not supported.") } // The nullability of output of joined could be different than original column, // so we can only compare them by exprId From 674f0f7b2bea052534ad1353fd532bc1459edd9d Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 28 Jan 2016 02:46:43 -0800 Subject: [PATCH 05/17] address comments --- .../sql/catalyst/analysis/Analyzer.scala | 26 ++++++++--- .../plans/logical/basicOperators.scala | 43 +++---------------- .../sql/catalyst/analysis/AnalysisSuite.scala | 21 ++++++--- .../sql/hive/execution/SQLQuerySuite.scala | 20 --------- 4 files changed, 43 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 87e0bacf31a4..117b0edb9936 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.NaturalJoin +import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, NaturalJoin} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.types._ @@ -1163,12 +1163,12 @@ class Analyzer( } /** - * Removes natural joins. + * Removes natural joins by calculating output columns based on output from two sides, + * Then apply a Project on a normal Join to eliminate natural join. */ object ResolveNaturalJoin extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case p if !p.resolved => p // Skip unresolved nodes. - + // Should not skip unresolved nodes because natural join is always unresolved. case j @ Join(left, right, NaturalJoin(joinType), condition) => val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) @@ -1177,7 +1177,23 @@ class Analyzer( val newCondition = (condition ++ joinPairs.map { case (l, r) => EqualTo(l, r) }).reduceLeftOption(And) - Project(j.outerProjectList(joinType), Join(left, right, joinType, newCondition)) + val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att)) + val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att)) + val projectList = joinType match { + case LeftOuter => + leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) + case RightOuter => + rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput + case FullOuter => + val joinedCols = joinPairs.map { + case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() + } + joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++ + rUniqueOutput.map(_.withNullability(true)) + case _ => + leftKeys ++ lUniqueOutput ++ rUniqueOutput + } + Project(projectList, Join(left, right, joinType, newCondition)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 98e73af4972c..e656f07fc8e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -164,50 +164,21 @@ case class Join( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case NaturalJoin(jt) => - outerProjectList(jt).map(_.toAttribute) case _ => left.output ++ right.output } } - def outerProjectList(jt: JoinType): Seq[NamedExpression] = { - val leftNames = left.output.map(_.name) - val rightNames = right.output.map(_.name) - val commonNames = leftNames.intersect(rightNames) - val commonOutputFromLeft = left.output.filter(att => commonNames.contains(att.name)) - val lUniqueOutput = left.output.filterNot(att => commonNames.contains(att.name)) - val rUniqueOutput = right.output.filterNot(att => commonNames.contains(att.name)) - jt match { - case LeftOuter => - commonOutputFromLeft ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) - case RightOuter => - val commonOutputFromRight = - commonNames.map(cn => right.output.find(att => att.name == cn).get) - commonOutputFromRight ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput - case FullOuter => - val commonOutputFromRight = - commonNames.map(cn => right.output.find(att => att.name == cn).get) - val commonPairs = commonOutputFromLeft.zip(commonOutputFromRight) - val commonOutputExp = commonPairs.map { - case (l: Attribute, r: Attribute) => - Alias(Coalesce(Seq(l, r)), l.name)() - } - commonOutputExp ++ - lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput.map(_.withNullability(true)) - case _ => - commonOutputFromLeft ++ lUniqueOutput ++ rUniqueOutput - } - } - def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. - override lazy val resolved: Boolean = { - childrenResolved && - expressions.forall(_.resolved) && - selfJoinResolved && - condition.forall(_.dataType == BooleanType) + override lazy val resolved: Boolean = joinType match { + case NaturalJoin(_) => false + case _ => + childrenResolved && + expressions.forall(_.resolved) && + selfJoinResolved && + condition.forall(_.dataType == BooleanType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index a22dcc775ec8..2bd49d629f23 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -251,18 +251,27 @@ class AnalysisSuite extends AnalysisTest { } test("check resolve of natural join") { + val t1 = testRelation2.select('a, 'b) + val t2 = testRelation2.select('a, 'c) val a = testRelation2.output(0) + val b = testRelation2.output(1) val c = testRelation2.output(2) - val t1 = testRelation2.select('a, 'b) - val t2 = testRelation2.select('a, 'c) val plan1 = t1.join(t2, NaturalJoin(Inner), None) - assertAnalysisSuccess(plan1) + val expected1 = testRelation2.select(a, b).join( + testRelation2.select(a, c), Inner, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan1, expected1) val plan2 = t1.join(t2, NaturalJoin(LeftOuter), None) - assertAnalysisSuccess(plan2) + val expected2 = testRelation2.select(a, b).join( + testRelation2.select(a, c), LeftOuter, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan2, expected2) val plan3 = t1.join(t2, NaturalJoin(RightOuter), None) - assertAnalysisSuccess(plan3) + val expected3 = testRelation2.select(a, b).join( + testRelation2.select(a, c), RightOuter, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan3, expected3) val plan4 = t1.join(t2, NaturalJoin(FullOuter), None) - assertAnalysisSuccess(plan4) + val expected4 = testRelation2.select(a, b).join(testRelation2.select( + a, c), FullOuter, Some(EqualTo(a, a))).select(Alias(Coalesce(Seq(a, a)), "a")(), b, c) + checkAnalysis(plan4, expected4) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index de5c677eb164..683008960aa2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1463,24 +1463,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", BigDecimal("3.14"), "hello")) } - - test("SPARK-12828: natural join support") { - Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1").registerTempTable("nt1") - Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2").registerTempTable("nt2") - checkAnswer( - sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""), - Row("one", 1, 1) :: Row("one", 1, 5) :: Nil) - - checkAnswer( - sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"), - Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil) - - checkAnswer( - sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"), - Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil) - - checkAnswer( - sql("SELECT count(*) FROM nt1 natural full outer join nt2"), - Row(4) :: Nil) - } } From 572ed69a4754363ae48823fcddbc26ce097384c7 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 29 Jan 2016 01:52:04 +0800 Subject: [PATCH 06/17] Update Analyzer.scala --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 117b0edb9936..da52dcf30f96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, NaturalJoin} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.types._ From 9fe467c1b1d5a1a2a700ecd92339ed2f16a73239 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 29 Jan 2016 01:55:31 +0800 Subject: [PATCH 07/17] Update basicOperators.scala --- .../sql/catalyst/plans/logical/basicOperators.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e656f07fc8e6..e47825c1f46f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -171,14 +171,16 @@ case class Join( def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + lazy val partlyResolved: Boolean = { + childrenResolved && + expressions.forall(_.resolved) && + selfJoinResolved && + condition.forall(_.dataType == BooleanType) + } // Joins are only resolved if they don't introduce ambiguous expression ids. override lazy val resolved: Boolean = joinType match { case NaturalJoin(_) => false - case _ => - childrenResolved && - expressions.forall(_.resolved) && - selfJoinResolved && - condition.forall(_.dataType == BooleanType) + case _ => partlyResolved } } From 88a52c2381b0c90c2ffa5c76795a99e55cda20bb Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 29 Jan 2016 01:55:59 +0800 Subject: [PATCH 08/17] Update Analyzer.scala --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index da52dcf30f96..83067538e0dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1169,7 +1169,7 @@ class Analyzer( object ResolveNaturalJoin extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Should not skip unresolved nodes because natural join is always unresolved. - case j @ Join(left, right, NaturalJoin(joinType), condition) => + case j @ Join(left, right, NaturalJoin(joinType), condition) if j.partlyResolved => val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) From cb8af0ec571c58f306e73187a01c98e4301e65c5 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 28 Jan 2016 21:03:03 -0800 Subject: [PATCH 09/17] address comments --- .../sql/catalyst/analysis/Analyzer.scala | 10 +++++-- .../sql/catalyst/analysis/CheckAnalysis.scala | 3 -- .../plans/logical/basicOperators.scala | 8 +++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 29 +++++++++++++++++++ 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 83067538e0dd..5b7ef59ca90f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1169,30 +1169,36 @@ class Analyzer( object ResolveNaturalJoin extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Should not skip unresolved nodes because natural join is always unresolved. - case j @ Join(left, right, NaturalJoin(joinType), condition) if j.partlyResolved => + case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => + // find common column names from both sides, should be treated like usingColumns val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) val joinPairs = leftKeys.zip(rightKeys) + // Add joinPairs to joinConditions val newCondition = (condition ++ joinPairs.map { case (l, r) => EqualTo(l, r) }).reduceLeftOption(And) + // columns not in joinPairs val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att)) val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att)) + // we should only keep unique columns(depends on joinType) for joinCols val projectList = joinType match { case LeftOuter => leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) case RightOuter => rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput case FullOuter => + // in full outer join, joinCols should be non-null if there is. val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() } joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput.map(_.withNullability(true)) case _ => - leftKeys ++ lUniqueOutput ++ rUniqueOutput + rightKeys ++ lUniqueOutput ++ rUniqueOutput } + // use Project to trim unnecessary fields Project(projectList, Join(left, right, joinType, newCondition)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a33ad5f24dd2..4ec9f1fe70a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -105,9 +105,6 @@ trait CheckAnalysis { s"filter expression '${f.condition.prettyString}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case j @ Join(_, _, NaturalJoin(_), _) => - failAnalysis(s"natural join not resolved.") - case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( s"join condition '${condition.prettyString}' " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e47825c1f46f..f193d91d74cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -171,16 +171,20 @@ case class Join( def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - lazy val partlyResolved: Boolean = { + // if not a natural join, it is resolved. if it is a natural join, we still need to + // eliminate natural before we mark it resolved, but the node should be ready for + // resolution only if everything else is resolved here. + lazy val resolvedExceptNatural: Boolean = { childrenResolved && expressions.forall(_.resolved) && selfJoinResolved && condition.forall(_.dataType == BooleanType) } + // Joins are only resolved if they don't introduce ambiguous expression ids. override lazy val resolved: Boolean = joinType match { case NaturalJoin(_) => false - case _ => partlyResolved + case _ => resolvedExceptNatural } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 2bd49d629f23..b076ebc26fe9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -256,6 +256,17 @@ class AnalysisSuite extends AnalysisTest { val a = testRelation2.output(0) val b = testRelation2.output(1) val c = testRelation2.output(2) + val testRelation0 = LocalRelation( + AttributeReference("a", StringType, nullable = false)(), + AttributeReference("b", StringType, nullable = false)(), + AttributeReference("c", StringType, nullable = false)()) + val tt1 = testRelation0.select('a, 'b) + val tt2 = testRelation0.select('a, 'c) + val aa = testRelation0.output(0) + val bb = testRelation0.output(1) + val cc = testRelation0.output(2) + val truebb = testRelation0.output(1).withNullability(true) + val truecc = testRelation0.output(2).withNullability(true) val plan1 = t1.join(t2, NaturalJoin(Inner), None) val expected1 = testRelation2.select(a, b).join( @@ -273,5 +284,23 @@ class AnalysisSuite extends AnalysisTest { val expected4 = testRelation2.select(a, b).join(testRelation2.select( a, c), FullOuter, Some(EqualTo(a, a))).select(Alias(Coalesce(Seq(a, a)), "a")(), b, c) checkAnalysis(plan4, expected4) + + val plan5 = tt1.join(tt2, NaturalJoin(Inner), None) + val expected5 = testRelation0.select(aa, bb).join( + testRelation0.select(aa, cc), Inner, Some(EqualTo(aa, aa))).select(aa, bb, cc) + checkAnalysis(plan5, expected5) + val plan6 = tt1.join(tt2, NaturalJoin(LeftOuter), None) + val expected6 = testRelation0.select(aa, bb).join( + testRelation0.select(aa, cc), LeftOuter, Some(EqualTo(aa, aa))).select(aa, bb, truecc) + checkAnalysis(plan6, expected6) + val plan7 = tt1.join(tt2, NaturalJoin(RightOuter), None) + val expected7 = testRelation0.select(aa, bb).join( + testRelation0.select(aa, cc), RightOuter, Some(EqualTo(aa, aa))).select(aa, truebb, cc) + checkAnalysis(plan7, expected7) + val plan8 = tt1.join(tt2, NaturalJoin(FullOuter), None) + val expected8 = testRelation0.select(aa, bb).join( + testRelation0.select(aa, cc), FullOuter, Some(EqualTo(aa, aa))).select( + Alias(Coalesce(Seq(aa, aa)), "a")(), truebb, truecc) + checkAnalysis(plan8, expected8) } } From 7e7de890c71284babb9c34289c12d0e229c17e60 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 29 Jan 2016 01:18:31 -0800 Subject: [PATCH 10/17] in a separate suite --- .../sql/catalyst/analysis/AnalysisSuite.scala | 55 ---------- .../analysis/ResolveNaturalJoinSuite.scala | 103 ++++++++++++++++++ 2 files changed, 103 insertions(+), 55 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index b076ebc26fe9..ab680282208c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -249,58 +248,4 @@ class AnalysisSuite extends AnalysisTest { val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val")) assertAnalysisSuccess(plan) } - - test("check resolve of natural join") { - val t1 = testRelation2.select('a, 'b) - val t2 = testRelation2.select('a, 'c) - val a = testRelation2.output(0) - val b = testRelation2.output(1) - val c = testRelation2.output(2) - val testRelation0 = LocalRelation( - AttributeReference("a", StringType, nullable = false)(), - AttributeReference("b", StringType, nullable = false)(), - AttributeReference("c", StringType, nullable = false)()) - val tt1 = testRelation0.select('a, 'b) - val tt2 = testRelation0.select('a, 'c) - val aa = testRelation0.output(0) - val bb = testRelation0.output(1) - val cc = testRelation0.output(2) - val truebb = testRelation0.output(1).withNullability(true) - val truecc = testRelation0.output(2).withNullability(true) - - val plan1 = t1.join(t2, NaturalJoin(Inner), None) - val expected1 = testRelation2.select(a, b).join( - testRelation2.select(a, c), Inner, Some(EqualTo(a, a))).select(a, b, c) - checkAnalysis(plan1, expected1) - val plan2 = t1.join(t2, NaturalJoin(LeftOuter), None) - val expected2 = testRelation2.select(a, b).join( - testRelation2.select(a, c), LeftOuter, Some(EqualTo(a, a))).select(a, b, c) - checkAnalysis(plan2, expected2) - val plan3 = t1.join(t2, NaturalJoin(RightOuter), None) - val expected3 = testRelation2.select(a, b).join( - testRelation2.select(a, c), RightOuter, Some(EqualTo(a, a))).select(a, b, c) - checkAnalysis(plan3, expected3) - val plan4 = t1.join(t2, NaturalJoin(FullOuter), None) - val expected4 = testRelation2.select(a, b).join(testRelation2.select( - a, c), FullOuter, Some(EqualTo(a, a))).select(Alias(Coalesce(Seq(a, a)), "a")(), b, c) - checkAnalysis(plan4, expected4) - - val plan5 = tt1.join(tt2, NaturalJoin(Inner), None) - val expected5 = testRelation0.select(aa, bb).join( - testRelation0.select(aa, cc), Inner, Some(EqualTo(aa, aa))).select(aa, bb, cc) - checkAnalysis(plan5, expected5) - val plan6 = tt1.join(tt2, NaturalJoin(LeftOuter), None) - val expected6 = testRelation0.select(aa, bb).join( - testRelation0.select(aa, cc), LeftOuter, Some(EqualTo(aa, aa))).select(aa, bb, truecc) - checkAnalysis(plan6, expected6) - val plan7 = tt1.join(tt2, NaturalJoin(RightOuter), None) - val expected7 = testRelation0.select(aa, bb).join( - testRelation0.select(aa, cc), RightOuter, Some(EqualTo(aa, aa))).select(aa, truebb, cc) - checkAnalysis(plan7, expected7) - val plan8 = tt1.join(tt2, NaturalJoin(FullOuter), None) - val expected8 = testRelation0.select(aa, bb).join( - testRelation0.select(aa, cc), FullOuter, Some(EqualTo(aa, aa))).select( - Alias(Coalesce(Seq(aa, aa)), "a")(), truebb, truecc) - checkAnalysis(plan8, expected8) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala new file mode 100644 index 000000000000..b73238db17ef --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types.StringType + +class ResolveNaturalJoinSuite extends AnalysisTest { + import org.apache.spark.sql.catalyst.analysis.TestRelations._ + + val t1 = testRelation2.select('a, 'b) + val t2 = testRelation2.select('a, 'c) + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + val testRelation0 = LocalRelation( + AttributeReference("a", StringType, nullable = false)(), + AttributeReference("b", StringType, nullable = false)(), + AttributeReference("c", StringType, nullable = false)()) + val tt1 = testRelation0.select('a, 'b) + val tt2 = testRelation0.select('a, 'c) + val aa = testRelation0.output(0) + val bb = testRelation0.output(1) + val cc = testRelation0.output(2) + val trueB = testRelation0.output(1).withNullability(true) + val trueC = testRelation0.output(2).withNullability(true) + + test("natural inner join") { + val plan = t1.join(t2, NaturalJoin(Inner), None) + val expected = testRelation2.select(a, b).join( + testRelation2.select(a, c), Inner, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan, expected) + } + + test("natural left join") { + val plan = t1.join(t2, NaturalJoin(LeftOuter), None) + val expected = testRelation2.select(a, b).join( + testRelation2.select(a, c), LeftOuter, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan, expected) + } + + test("natural right join") { + val plan = t1.join(t2, NaturalJoin(RightOuter), None) + val expected = testRelation2.select(a, b).join( + testRelation2.select(a, c), RightOuter, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan, expected) + } + + test("natural full outer join") { + val plan = t1.join(t2, NaturalJoin(FullOuter), None) + val expected = testRelation2.select(a, b).join(testRelation2.select( + a, c), FullOuter, Some(EqualTo(a, a))).select(Alias(Coalesce(Seq(a, a)), "a")(), b, c) + checkAnalysis(plan, expected) + } + + test("natural inner join with no nullability") { + val plan = tt1.join(tt2, NaturalJoin(Inner), None) + val expected = testRelation0.select(aa, bb).join( + testRelation0.select(aa, cc), Inner, Some(EqualTo(aa, aa))).select(aa, bb, cc) + checkAnalysis(plan, expected) + } + + test("natural left join with no nullability") { + val plan = tt1.join(tt2, NaturalJoin(LeftOuter), None) + val expected = testRelation0.select(aa, bb).join( + testRelation0.select(aa, cc), LeftOuter, Some(EqualTo(aa, aa))).select(aa, bb, trueC) + checkAnalysis(plan, expected) + } + + test("natural right join with no nullability") { + val plan = tt1.join(tt2, NaturalJoin(RightOuter), None) + val expected = testRelation0.select(aa, bb).join( + testRelation0.select(aa, cc), RightOuter, Some(EqualTo(aa, aa))).select(aa, trueB, cc) + checkAnalysis(plan, expected) + } + + test("natural full outer join with no nullability") { + val plan = tt1.join(tt2, NaturalJoin(FullOuter), None) + val expected = testRelation0.select(aa, bb).join( + testRelation0.select(aa, cc), FullOuter, Some(EqualTo(aa, aa))).select( + Alias(Coalesce(Seq(aa, aa)), "a")(), trueB, trueC) + checkAnalysis(plan, expected) + } +} From 2cf56290545f3faf67eaff824ab9823f33405074 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 29 Jan 2016 01:50:10 -0800 Subject: [PATCH 11/17] lazy val --- .../analysis/ResolveNaturalJoinSuite.scala | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala index b73238db17ef..0b831e5ce031 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -27,22 +27,22 @@ import org.apache.spark.sql.types.StringType class ResolveNaturalJoinSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.analysis.TestRelations._ - val t1 = testRelation2.select('a, 'b) - val t2 = testRelation2.select('a, 'c) - val a = testRelation2.output(0) - val b = testRelation2.output(1) - val c = testRelation2.output(2) - val testRelation0 = LocalRelation( + lazy val t1 = testRelation2.select('a, 'b) + lazy val t2 = testRelation2.select('a, 'c) + lazy val a = testRelation2.output(0) + lazy val b = testRelation2.output(1) + lazy val c = testRelation2.output(2) + lazy val testRelation0 = LocalRelation( AttributeReference("a", StringType, nullable = false)(), AttributeReference("b", StringType, nullable = false)(), AttributeReference("c", StringType, nullable = false)()) - val tt1 = testRelation0.select('a, 'b) - val tt2 = testRelation0.select('a, 'c) - val aa = testRelation0.output(0) - val bb = testRelation0.output(1) - val cc = testRelation0.output(2) - val trueB = testRelation0.output(1).withNullability(true) - val trueC = testRelation0.output(2).withNullability(true) + lazy val tt1 = testRelation0.select('a, 'b) + lazy val tt2 = testRelation0.select('a, 'c) + lazy val aa = testRelation0.output(0) + lazy val bb = testRelation0.output(1) + lazy val cc = testRelation0.output(2) + lazy val trueB = testRelation0.output(1).withNullability(true) + lazy val trueC = testRelation0.output(2).withNullability(true) test("natural inner join") { val plan = t1.join(t2, NaturalJoin(Inner), None) From 192f8bfab03243215c87a6bf2064e860b1606a3f Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 29 Jan 2016 18:31:13 +0800 Subject: [PATCH 12/17] Update CheckAnalysis.scala --- .../org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4ec9f1fe70a8..f2e78d97442e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.NaturalJoin import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ From 12de061906bf5815845349822a737aea0b704b92 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sun, 31 Jan 2016 18:23:27 -0800 Subject: [PATCH 13/17] improve some doc --- .../spark/sql/catalyst/plans/logical/basicOperators.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index f193d91d74cb..7f393b47bcc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -171,9 +171,8 @@ case class Join( def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - // if not a natural join, it is resolved. if it is a natural join, we still need to - // eliminate natural before we mark it resolved, but the node should be ready for - // resolution only if everything else is resolved here. + // Joins are only resolved if they don't introduce ambiguous expression ids. + // NaturalJoin should be ready for resolution only if everything else is resolved here lazy val resolvedExceptNatural: Boolean = { childrenResolved && expressions.forall(_.resolved) && @@ -181,7 +180,8 @@ case class Join( condition.forall(_.dataType == BooleanType) } - // Joins are only resolved if they don't introduce ambiguous expression ids. + // if not a natural join, use `resolvedExceptNatural`. if it is a natural join, we still need + // to eliminate natural before we mark it resolved. override lazy val resolved: Boolean = joinType match { case NaturalJoin(_) => false case _ => resolvedExceptNatural From b3a6c3259b56db5af4ddf05850b5b9ee0c2789b9 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sun, 31 Jan 2016 21:14:41 -0800 Subject: [PATCH 14/17] use withTempTable --- .../org/apache/spark/sql/SQLQuerySuite.scala | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 9b40688b0234..7c507628e925 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2058,22 +2058,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("natural join") { - Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1").registerTempTable("nt1") - Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2").registerTempTable("nt2") - checkAnswer( - sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""), - Row("one", 1, 1) :: Row("one", 1, 5) :: Nil) + val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1") + val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2") + withTempTable("nt1", "nt2") { + df1.registerTempTable("nt1") + df2.registerTempTable("nt2") + checkAnswer( + sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""), + Row("one", 1, 1) :: Row("one", 1, 5) :: Nil) - checkAnswer( - sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"), - Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil) + checkAnswer( + sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil) - checkAnswer( - sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"), - Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil) + checkAnswer( + sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil) - checkAnswer( - sql("SELECT count(*) FROM nt1 natural full outer join nt2"), - Row(4) :: Nil) + checkAnswer( + sql("SELECT count(*) FROM nt1 natural full outer join nt2"), + Row(4) :: Nil) + } } } From 6aa2a7927baa0cc02bacfb0df6e77744c61bc39a Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sun, 31 Jan 2016 21:26:34 -0800 Subject: [PATCH 15/17] rename val --- .../analysis/ResolveNaturalJoinSuite.scala | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala index 0b831e5ce031..e466fe1acd89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -33,16 +33,16 @@ class ResolveNaturalJoinSuite extends AnalysisTest { lazy val b = testRelation2.output(1) lazy val c = testRelation2.output(2) lazy val testRelation0 = LocalRelation( - AttributeReference("a", StringType, nullable = false)(), - AttributeReference("b", StringType, nullable = false)(), - AttributeReference("c", StringType, nullable = false)()) - lazy val tt1 = testRelation0.select('a, 'b) - lazy val tt2 = testRelation0.select('a, 'c) - lazy val aa = testRelation0.output(0) - lazy val bb = testRelation0.output(1) - lazy val cc = testRelation0.output(2) - lazy val trueB = testRelation0.output(1).withNullability(true) - lazy val trueC = testRelation0.output(2).withNullability(true) + AttributeReference("d", StringType, nullable = false)(), + AttributeReference("e", StringType, nullable = false)(), + AttributeReference("f", StringType, nullable = false)()) + lazy val t3 = testRelation0.select('d, 'e) + lazy val t4 = testRelation0.select('d, 'f) + lazy val d = testRelation0.output(0) + lazy val e = testRelation0.output(1) + lazy val f = testRelation0.output(2) + lazy val nullableE = testRelation0.output(1).withNullability(true) + lazy val nullableF = testRelation0.output(2).withNullability(true) test("natural inner join") { val plan = t1.join(t2, NaturalJoin(Inner), None) @@ -73,31 +73,30 @@ class ResolveNaturalJoinSuite extends AnalysisTest { } test("natural inner join with no nullability") { - val plan = tt1.join(tt2, NaturalJoin(Inner), None) - val expected = testRelation0.select(aa, bb).join( - testRelation0.select(aa, cc), Inner, Some(EqualTo(aa, aa))).select(aa, bb, cc) + val plan = t3.join(t4, NaturalJoin(Inner), None) + val expected = testRelation0.select(d, e).join( + testRelation0.select(d, f), Inner, Some(EqualTo(d, d))).select(d, e, f) checkAnalysis(plan, expected) } test("natural left join with no nullability") { - val plan = tt1.join(tt2, NaturalJoin(LeftOuter), None) - val expected = testRelation0.select(aa, bb).join( - testRelation0.select(aa, cc), LeftOuter, Some(EqualTo(aa, aa))).select(aa, bb, trueC) + val plan = t3.join(t4, NaturalJoin(LeftOuter), None) + val expected = testRelation0.select(d, e).join( + testRelation0.select(d, f), LeftOuter, Some(EqualTo(d, d))).select(d, e, nullableF) checkAnalysis(plan, expected) } test("natural right join with no nullability") { - val plan = tt1.join(tt2, NaturalJoin(RightOuter), None) - val expected = testRelation0.select(aa, bb).join( - testRelation0.select(aa, cc), RightOuter, Some(EqualTo(aa, aa))).select(aa, trueB, cc) + val plan = t3.join(t4, NaturalJoin(RightOuter), None) + val expected = testRelation0.select(d, e).join( + testRelation0.select(d, f), RightOuter, Some(EqualTo(d, d))).select(d, nullableE, f) checkAnalysis(plan, expected) } test("natural full outer join with no nullability") { - val plan = tt1.join(tt2, NaturalJoin(FullOuter), None) - val expected = testRelation0.select(aa, bb).join( - testRelation0.select(aa, cc), FullOuter, Some(EqualTo(aa, aa))).select( - Alias(Coalesce(Seq(aa, aa)), "a")(), trueB, trueC) + val plan = t3.join(t4, NaturalJoin(FullOuter), None) + val expected = testRelation0.select(d, e).join(testRelation0.select(d, f), FullOuter, Some( + EqualTo(d, d))).select(Alias(Coalesce(Seq(d, d)), "d")(), nullableE, nullableF) checkAnalysis(plan, expected) } } From 42f9d2ceb82038cf9a401e2925366d0ff47a0425 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sun, 31 Jan 2016 23:47:20 -0800 Subject: [PATCH 16/17] rename val --- .../analysis/ResolveNaturalJoinSuite.scala | 70 ++++++++++--------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala index e466fe1acd89..459692eea647 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -25,78 +25,80 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.StringType class ResolveNaturalJoinSuite extends AnalysisTest { - import org.apache.spark.sql.catalyst.analysis.TestRelations._ - - lazy val t1 = testRelation2.select('a, 'b) - lazy val t2 = testRelation2.select('a, 'c) - lazy val a = testRelation2.output(0) - lazy val b = testRelation2.output(1) - lazy val c = testRelation2.output(2) - lazy val testRelation0 = LocalRelation( - AttributeReference("d", StringType, nullable = false)(), - AttributeReference("e", StringType, nullable = false)(), - AttributeReference("f", StringType, nullable = false)()) - lazy val t3 = testRelation0.select('d, 'e) - lazy val t4 = testRelation0.select('d, 'f) - lazy val d = testRelation0.output(0) - lazy val e = testRelation0.output(1) - lazy val f = testRelation0.output(2) - lazy val nullableE = testRelation0.output(1).withNullability(true) - lazy val nullableF = testRelation0.output(2).withNullability(true) + lazy val r1 = LocalRelation( + AttributeReference("r1_a", StringType, nullable = true)(), + AttributeReference("r1_b", StringType, nullable = true)(), + AttributeReference("r1_c", StringType, nullable = true)()) + lazy val r2 = LocalRelation( + AttributeReference("r2_a", StringType, nullable = false)(), + AttributeReference("r2_b", StringType, nullable = false)(), + AttributeReference("r2_c", StringType, nullable = false)()) + lazy val t1 = r1.select('r1_a, 'r1_b) + lazy val t2 = r1.select('r1_a, 'r1_c) + lazy val r1a = r1.output(0) + lazy val r1b = r1.output(1) + lazy val r1c = r1.output(2) + lazy val t3 = r2.select('r2_a, 'r2_b) + lazy val t4 = r2.select('r2_a, 'r2_c) + lazy val r2a = r2.output(0) + lazy val r2b = r2.output(1) + lazy val r2c = r2.output(2) + lazy val nullableR2B = r2.output(1).withNullability(true) + lazy val nullableR2C = r2.output(2).withNullability(true) test("natural inner join") { val plan = t1.join(t2, NaturalJoin(Inner), None) - val expected = testRelation2.select(a, b).join( - testRelation2.select(a, c), Inner, Some(EqualTo(a, a))).select(a, b, c) + val expected = r1.select(r1a, r1b).join( + r1.select(r1a, r1c), Inner, Some(EqualTo(r1a, r1a))).select(r1a, r1b, r1c) checkAnalysis(plan, expected) } test("natural left join") { val plan = t1.join(t2, NaturalJoin(LeftOuter), None) - val expected = testRelation2.select(a, b).join( - testRelation2.select(a, c), LeftOuter, Some(EqualTo(a, a))).select(a, b, c) + val expected = r1.select(r1a, r1b).join( + r1.select(r1a, r1c), LeftOuter, Some(EqualTo(r1a, r1a))).select(r1a, r1b, r1c) checkAnalysis(plan, expected) } test("natural right join") { val plan = t1.join(t2, NaturalJoin(RightOuter), None) - val expected = testRelation2.select(a, b).join( - testRelation2.select(a, c), RightOuter, Some(EqualTo(a, a))).select(a, b, c) + val expected = r1.select(r1a, r1b).join( + r1.select(r1a, r1c), RightOuter, Some(EqualTo(r1a, r1a))).select(r1a, r1b, r1c) checkAnalysis(plan, expected) } test("natural full outer join") { val plan = t1.join(t2, NaturalJoin(FullOuter), None) - val expected = testRelation2.select(a, b).join(testRelation2.select( - a, c), FullOuter, Some(EqualTo(a, a))).select(Alias(Coalesce(Seq(a, a)), "a")(), b, c) + val expected = r1.select(r1a, r1b).join(r1.select(r1a, r1c), FullOuter, Some( + EqualTo(r1a, r1a))).select(Alias(Coalesce(Seq(r1a, r1a)), "r1_a")(), r1b, r1c) checkAnalysis(plan, expected) } test("natural inner join with no nullability") { val plan = t3.join(t4, NaturalJoin(Inner), None) - val expected = testRelation0.select(d, e).join( - testRelation0.select(d, f), Inner, Some(EqualTo(d, d))).select(d, e, f) + val expected = r2.select(r2a, r2b).join( + r2.select(r2a, r2c), Inner, Some(EqualTo(r2a, r2a))).select(r2a, r2b, r2c) checkAnalysis(plan, expected) } test("natural left join with no nullability") { val plan = t3.join(t4, NaturalJoin(LeftOuter), None) - val expected = testRelation0.select(d, e).join( - testRelation0.select(d, f), LeftOuter, Some(EqualTo(d, d))).select(d, e, nullableF) + val expected = r2.select(r2a, r2b).join( + r2.select(r2a, r2c), LeftOuter, Some(EqualTo(r2a, r2a))).select(r2a, r2b, nullableR2C) checkAnalysis(plan, expected) } test("natural right join with no nullability") { val plan = t3.join(t4, NaturalJoin(RightOuter), None) - val expected = testRelation0.select(d, e).join( - testRelation0.select(d, f), RightOuter, Some(EqualTo(d, d))).select(d, nullableE, f) + val expected = r2.select(r2a, r2b).join( + r2.select(r2a, r2c), RightOuter, Some(EqualTo(r2a, r2a))).select(r2a, nullableR2B, r2c) checkAnalysis(plan, expected) } test("natural full outer join with no nullability") { val plan = t3.join(t4, NaturalJoin(FullOuter), None) - val expected = testRelation0.select(d, e).join(testRelation0.select(d, f), FullOuter, Some( - EqualTo(d, d))).select(Alias(Coalesce(Seq(d, d)), "d")(), nullableE, nullableF) + val expected = r2.select(r2a, r2b).join(r2.select(r2a, r2c), FullOuter, Some(EqualTo( + r2a, r2a))).select(Alias(Coalesce(Seq(r2a, r2a)), "r2_a")(), nullableR2B, nullableR2C) checkAnalysis(plan, expected) } } From 307cb5e95b891637677bf79f0bf4858cb2ed2bc4 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 1 Feb 2016 02:14:40 -0800 Subject: [PATCH 17/17] use dsl --- .../analysis/ResolveNaturalJoinSuite.scala | 76 ++++++++----------- 1 file changed, 31 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala index 459692eea647..a6554fbc414b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -22,83 +22,69 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.StringType class ResolveNaturalJoinSuite extends AnalysisTest { - lazy val r1 = LocalRelation( - AttributeReference("r1_a", StringType, nullable = true)(), - AttributeReference("r1_b", StringType, nullable = true)(), - AttributeReference("r1_c", StringType, nullable = true)()) - lazy val r2 = LocalRelation( - AttributeReference("r2_a", StringType, nullable = false)(), - AttributeReference("r2_b", StringType, nullable = false)(), - AttributeReference("r2_c", StringType, nullable = false)()) - lazy val t1 = r1.select('r1_a, 'r1_b) - lazy val t2 = r1.select('r1_a, 'r1_c) - lazy val r1a = r1.output(0) - lazy val r1b = r1.output(1) - lazy val r1c = r1.output(2) - lazy val t3 = r2.select('r2_a, 'r2_b) - lazy val t4 = r2.select('r2_a, 'r2_c) - lazy val r2a = r2.output(0) - lazy val r2b = r2.output(1) - lazy val r2c = r2.output(2) - lazy val nullableR2B = r2.output(1).withNullability(true) - lazy val nullableR2C = r2.output(2).withNullability(true) + lazy val a = 'a.string + lazy val b = 'b.string + lazy val c = 'c.string + lazy val aNotNull = a.notNull + lazy val bNotNull = b.notNull + lazy val cNotNull = c.notNull + lazy val r1 = LocalRelation(a, b) + lazy val r2 = LocalRelation(a, c) + lazy val r3 = LocalRelation(aNotNull, bNotNull) + lazy val r4 = LocalRelation(bNotNull, cNotNull) test("natural inner join") { - val plan = t1.join(t2, NaturalJoin(Inner), None) - val expected = r1.select(r1a, r1b).join( - r1.select(r1a, r1c), Inner, Some(EqualTo(r1a, r1a))).select(r1a, r1b, r1c) + val plan = r1.join(r2, NaturalJoin(Inner), None) + val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c) checkAnalysis(plan, expected) } test("natural left join") { - val plan = t1.join(t2, NaturalJoin(LeftOuter), None) - val expected = r1.select(r1a, r1b).join( - r1.select(r1a, r1c), LeftOuter, Some(EqualTo(r1a, r1a))).select(r1a, r1b, r1c) + val plan = r1.join(r2, NaturalJoin(LeftOuter), None) + val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c) checkAnalysis(plan, expected) } test("natural right join") { - val plan = t1.join(t2, NaturalJoin(RightOuter), None) - val expected = r1.select(r1a, r1b).join( - r1.select(r1a, r1c), RightOuter, Some(EqualTo(r1a, r1a))).select(r1a, r1b, r1c) + val plan = r1.join(r2, NaturalJoin(RightOuter), None) + val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c) checkAnalysis(plan, expected) } test("natural full outer join") { - val plan = t1.join(t2, NaturalJoin(FullOuter), None) - val expected = r1.select(r1a, r1b).join(r1.select(r1a, r1c), FullOuter, Some( - EqualTo(r1a, r1a))).select(Alias(Coalesce(Seq(r1a, r1a)), "r1_a")(), r1b, r1c) + val plan = r1.join(r2, NaturalJoin(FullOuter), None) + val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select( + Alias(Coalesce(Seq(a, a)), "a")(), b, c) checkAnalysis(plan, expected) } test("natural inner join with no nullability") { - val plan = t3.join(t4, NaturalJoin(Inner), None) - val expected = r2.select(r2a, r2b).join( - r2.select(r2a, r2c), Inner, Some(EqualTo(r2a, r2a))).select(r2a, r2b, r2c) + val plan = r3.join(r4, NaturalJoin(Inner), None) + val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, aNotNull, cNotNull) checkAnalysis(plan, expected) } test("natural left join with no nullability") { - val plan = t3.join(t4, NaturalJoin(LeftOuter), None) - val expected = r2.select(r2a, r2b).join( - r2.select(r2a, r2c), LeftOuter, Some(EqualTo(r2a, r2a))).select(r2a, r2b, nullableR2C) + val plan = r3.join(r4, NaturalJoin(LeftOuter), None) + val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, aNotNull, c) checkAnalysis(plan, expected) } test("natural right join with no nullability") { - val plan = t3.join(t4, NaturalJoin(RightOuter), None) - val expected = r2.select(r2a, r2b).join( - r2.select(r2a, r2c), RightOuter, Some(EqualTo(r2a, r2a))).select(r2a, nullableR2B, r2c) + val plan = r3.join(r4, NaturalJoin(RightOuter), None) + val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, a, cNotNull) checkAnalysis(plan, expected) } test("natural full outer join with no nullability") { - val plan = t3.join(t4, NaturalJoin(FullOuter), None) - val expected = r2.select(r2a, r2b).join(r2.select(r2a, r2c), FullOuter, Some(EqualTo( - r2a, r2a))).select(Alias(Coalesce(Seq(r2a, r2a)), "r2_a")(), nullableR2B, nullableR2C) + val plan = r3.join(r4, NaturalJoin(FullOuter), None) + val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select( + Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c) checkAnalysis(plan, expected) } }