From 3555e0fd1a8003af0e6a7694aab8999698aab9c4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 1 Jun 2018 02:00:42 +0800 Subject: [PATCH] comparison should accept structurally-equal types --- .../sql/catalyst/analysis/TypeCoercion.scala | 66 ++++++++++++--- .../sql/catalyst/expressions/Expression.scala | 6 +- .../spark/sql/execution/joins/HashJoin.scala | 5 +- .../resources/sql-tests/inputs/comparator.sql | 24 ++++++ .../sql-tests/results/comparator.sql.out | 80 +++++++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 19 ----- 6 files changed, 160 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b2817b0538a7f..10de40f59bf7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -803,18 +803,60 @@ object TypeCoercion { e.copy(left = Cast(e.left, TimestampType)) } - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonType(left.dataType, right.dataType).map { commonType => - if (b.inputType.acceptsType(commonType)) { - // If the expression accepts the tightest common type, cast to that. - val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) - val newRight = if (right.dataType == commonType) right else Cast(right, commonType) - b.withNewChildren(Seq(newLeft, newRight)) - } else { - // Otherwise, don't do anything with the expression. - b - } - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. 
+ case b @ BinaryOperator(left, right) + if !BinaryOperator.sameType(left.dataType, right.dataType) => + (left.dataType, right.dataType) match { + case (StructType(fields1), StructType(fields2)) => + val commonTypes = scala.collection.mutable.ArrayBuffer.empty[DataType] + val len = fields1.length + var i = 0 + var continue = fields1.length == fields2.length + while (i < len && continue) { + val commonType = findTightestCommonType(fields1(i).dataType, fields2(i).dataType) + if (commonType.isDefined) { + commonTypes += commonType.get + } else { + continue = false + } + i += 1 + } + + if (continue) { + val newLeftST = new StructType(fields1.zip(commonTypes).map { + case (f, commonType) => f.copy(dataType = commonType) + }) + val newLeft = if (left.dataType == newLeftST) left else Cast(left, newLeftST) + + val newRightST = new StructType(fields2.zip(commonTypes).map { + case (f, commonType) => f.copy(dataType = commonType) + }) + val newRight = if (right.dataType == newRightST) right else Cast(right, newRightST) + + if (b.inputType.acceptsType(newLeftST) && b.inputType.acceptsType(newRightST)) { + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // type not acceptable, don't do anything with the expression. + b + } + } else { + // left struct type and right struct type have different number of fields, or some + // fields don't have a common type, don't do anything with the expression. + b + } + + case _ => + findTightestCommonType(left.dataType, right.dataType).map { commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tightest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. 
+ } case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 9b9fa41a47d0f..d4c5112970152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -578,7 +578,7 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { override def checkInputDataTypes(): TypeCheckResult = { // First check whether left and right have the same type, then check if the type is acceptable. - if (!left.dataType.sameType(right.dataType)) { + if (!BinaryOperator.sameType(left.dataType, right.dataType)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") } else if (!inputType.acceptsType(left.dataType)) { @@ -595,6 +595,10 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { object BinaryOperator { def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) + + def sameType(left: DataType, right: DataType): Boolean = { + DataType.equalsStructurally(left, right, ignoreNullability = true) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 0396168d3f311..a6d619f77a6a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -62,8 +62,9 @@ trait HashJoin { } protected lazy val (buildKeys, streamedKeys) = { - require(leftKeys.map(_.dataType) == 
rightKeys.map(_.dataType), - "Join keys from two sides should have same types") + require(leftKeys.length == rightKeys.length && leftKeys.zip(rightKeys).forall { + case (l, r) => BinaryOperator.sameType(l.dataType, r.dataType) + }, "Join keys from two sides should have same length and types") val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) val rkeys = HashJoin.rewriteKeyExpr(rightKeys) .map(BindReferences.bindReference(_, right.output)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/comparator.sql b/sql/core/src/test/resources/sql-tests/inputs/comparator.sql index 3e2447723e576..ea81cb928b878 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/comparator.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/comparator.sql @@ -1,3 +1,27 @@ +-- test various operator for comparison, including =, <=>, <, <=, >, >= + +create temporary view data as select * from values + (1, 1.0D, 'a'), + (2, 2.0D, 'b'), + (3, 3.0D, 'c'), + (null, null, null) + as data(i, j, k); + -- binary type select x'00' < x'0f'; select x'00' < x'ff'; + +-- int type +select i, i = 2, i = null, i <=> 2, i <=> null, i < 2, i <= 2, i > 2, i >= 2 from data; + +-- decimal type +select j, j = 2.0D, j = null, j <=> 2.0D, j <=> null, j < 2.0D, j <= 2.0D, j > 2.0D, j >= 2.0D from data; + +-- string type +select k, k = 'b', k = null, k <=> 'b', k <=> null, k < 'b', k <= 'b', k > 'b', k >= 'b' from data; + +-- struct type +select i, j, (i, j) = (2, 2.0D), (i, j) = null, (i, j) < (2, 3.0D) from data; + +-- implicit type cast +select i, j, i = 2L, (i, j) = (2L, 2.0D) from data; diff --git a/sql/core/src/test/resources/sql-tests/results/comparator.sql.out b/sql/core/src/test/resources/sql-tests/results/comparator.sql.out index afc7b5448b7b6..9a8b530cc63d8 100644 --- a/sql/core/src/test/resources/sql-tests/results/comparator.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/comparator.sql.out @@ -1,18 +1,86 @@ -- Automatically generated by SQLQueryTestSuite --- Number of 
queries: 2 +-- Number of queries: 8 -- !query 0 -select x'00' < x'0f' +create temporary view data as select * from values + (1, 1.0D, 'a'), + (2, 2.0D, 'b'), + (3, 3.0D, 'c'), + (null, null, null) + as data(i, j, k) -- !query 0 schema -struct<(X'00' < X'0F'):boolean> +struct<> -- !query 0 output -true + -- !query 1 -select x'00' < x'ff' +select x'00' < x'0f' -- !query 1 schema -struct<(X'00' < X'FF'):boolean> +struct<(X'00' < X'0F'):boolean> -- !query 1 output true + + +-- !query 2 +select x'00' < x'ff' +-- !query 2 schema +struct<(X'00' < X'FF'):boolean> +-- !query 2 output +true + + +-- !query 3 +select i, i = 2, i = null, i <=> 2, i <=> null, i < 2, i <= 2, i > 2, i >= 2 from data +-- !query 3 schema +struct 2):boolean,(i <=> CAST(NULL AS INT)):boolean,(i < 2):boolean,(i <= 2):boolean,(i > 2):boolean,(i >= 2):boolean> +-- !query 3 output +1 false NULL false false true true false false +2 true NULL true false false true false true +3 false NULL false false false false true true +NULL NULL NULL false true NULL NULL NULL NULL + + +-- !query 4 +select j, j = 2.0D, j = null, j <=> 2.0D, j <=> null, j < 2.0D, j <= 2.0D, j > 2.0D, j >= 2.0D from data +-- !query 4 schema +struct 2.0):boolean,(j <=> CAST(NULL AS DOUBLE)):boolean,(j < 2.0):boolean,(j <= 2.0):boolean,(j > 2.0):boolean,(j >= 2.0):boolean> +-- !query 4 output +1.0 false NULL false false true true false false +2.0 true NULL true false false true false true +3.0 false NULL false false false false true true +NULL NULL NULL false true NULL NULL NULL NULL + + +-- !query 5 +select k, k = 'b', k = null, k <=> 'b', k <=> null, k < 'b', k <= 'b', k > 'b', k >= 'b' from data +-- !query 5 schema +struct b):boolean,(k <=> CAST(NULL AS STRING)):boolean,(k < b):boolean,(k <= b):boolean,(k > b):boolean,(k >= b):boolean> +-- !query 5 output +NULL NULL NULL false true NULL NULL NULL NULL +a false NULL false false true true false false +b true NULL true false false true false true +c false NULL false false false false true 
true + + +-- !query 6 +select i, j, (i, j) = (2, 2.0D), (i, j) = null, (i, j) < (2, 3.0D) from data +-- !query 6 schema +struct +-- !query 6 output +1 1.0 false NULL true +2 2.0 true NULL true +3 3.0 false NULL false +NULL NULL false NULL true + + +-- !query 7 +select i, j, i = 2L, (i, j) = (2L, 2.0D) from data +-- !query 7 schema +struct +-- !query 7 output +1 1.0 false false +2 2.0 true true +3 3.0 false false +NULL NULL NULL false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 640affc10ee58..4a992fbbfaf41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2675,15 +2675,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))") sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") - - withTable("t", "S") { - sql("CREATE TABLE t(c struct) USING parquet") - sql("CREATE TABLE S(C struct) USING parquet") - Seq(("c", "C"), ("C", "c"), ("c.f", "C.F"), ("C.F", "c.f")).foreach { - case (left, right) => - checkAnswer(sql(s"SELECT * FROM t, S WHERE t.$left = S.$right"), Seq.empty) - } - } } withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { @@ -2696,16 +2687,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") }.message assert(m2.contains("Except can only be performed on tables with the compatible column types")) - - withTable("t", "S") { - sql("CREATE TABLE t(c struct) USING parquet") - sql("CREATE TABLE S(C struct) USING parquet") - checkAnswer(sql("SELECT * FROM t, S WHERE t.c.f = S.C.F"), Seq.empty) - val m = intercept[AnalysisException] { - sql("SELECT * FROM t, S WHERE c = C") - }.message - assert(m.contains("cannot resolve '(t.`c` = S.`C`)' due to data type mismatch")) - } } }