From ce4b93c1d7046b7a3e4bee4e924fd012f315f2ef Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 23 Oct 2016 13:57:31 -0400 Subject: [PATCH 1/2] SPARK-18058 --- .../plans/logical/basicLogicalOperators.scala | 28 ++++++------------- .../sql/catalyst/analysis/AnalysisSuite.scala | 19 +++++++++++++ 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 07e39b029894a..3ebe1a06b5535 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -116,6 +116,8 @@ case class Filter(condition: Expression, child: LogicalPlan) abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + protected def leftConstraints: Set[Expression] = left.constraints protected def rightConstraints: Set[Expression] = { @@ -125,6 +127,12 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar case a: Attribute => attributeRewrites(a) }) } + + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && + duplicateResolved } object SetOperation { @@ -133,8 +141,6 @@ object SetOperation { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) @@ -143,14 +149,6 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation override protected def validConstraints: Set[Expression] = leftConstraints.union(rightConstraints) - // Intersect are only resolved if they don't introduce ambiguous expression ids, - // since the Optimizer will convert Intersect to Join. - override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && - duplicateResolved - override def maxRows: Option[Long] = { if (children.exists(_.maxRows.isEmpty)) { None @@ -171,19 +169,11 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output override protected def validConstraints: Set[Expression] = leftConstraints - override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && - duplicateResolved - override lazy val statistics: Statistics = { left.statistics.copy() } @@ -218,7 +208,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { child.output.length == children.head.output.length && // compare the data types with the first child child.output.zip(children.head.output).forall { - case (l, r) => l.dataType == r.dataType } + case (l, r) => l.dataType.asNullable == r.dataType.asNullable } ) children.length > 1 && childrenResolved && allChildrenCompatible diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 102c78bd72111..1d50e0d25f7ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -377,4 +377,23 @@ class AnalysisSuite extends AnalysisTest { assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType) assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType) } + + test("SPARK-18058: union and set operations shall not care about the nullability" + + " when comparing column types") { + val firstTable = LocalRelation( + AttributeReference("a", + StructType(Seq(StructField("a", IntegerType, nullable = true))), nullable = false)()) + val secondTable = LocalRelation( + AttributeReference("a", + StructType(Seq(StructField("a", IntegerType, nullable = false))), nullable = false)()) + + val unionPlan = Union(firstTable, secondTable) + assertAnalysisSuccess(unionPlan) + + val r1 = Except(firstTable, secondTable) + val r2 = Intersect(firstTable, secondTable) + + assertAnalysisSuccess(r1) + assertAnalysisSuccess(r2) + } } From e0d806a5dc7698079f188851044b3cc1e0e39bb3 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 23 Oct 2016 13:59:57 -0400 Subject: [PATCH 2/2] SPARK-18058 --- .../plans/logical/basicLogicalOperators.scala | 3 ++- .../spark/sql/catalyst/analysis/AnalysisSuite.scala | 11 +++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3ebe1a06b5535..2fdaa057bab18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -131,7 +131,8 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar override lazy val resolved: Boolean = childrenResolved && left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && + left.output.zip(right.output).forall { + case (l, r) => l.dataType.asNullable == r.dataType.asNullable } && duplicateResolved } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 1d50e0d25f7ef..1ba4f4a36d1a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -387,13 +387,8 @@ class AnalysisSuite extends AnalysisTest { AttributeReference("a", StructType(Seq(StructField("a", IntegerType, nullable = false))), nullable = false)()) - val unionPlan = Union(firstTable, secondTable) - assertAnalysisSuccess(unionPlan) - - val r1 = Except(firstTable, secondTable) - val r2 = Intersect(firstTable, secondTable) - - assertAnalysisSuccess(r1) - assertAnalysisSuccess(r2) + assertAnalysisSuccess(Union(firstTable, secondTable)) + assertAnalysisSuccess(Except(firstTable, secondTable)) + assertAnalysisSuccess(Intersect(firstTable, secondTable)) } }