-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-24443][SQL] comparison should accept structurally-equal types #21470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -803,18 +803,60 @@ object TypeCoercion { | |
| e.copy(left = Cast(e.left, TimestampType)) | ||
| } | ||
|
|
||
| case b @ BinaryOperator(left, right) if left.dataType != right.dataType => | ||
| findTightestCommonType(left.dataType, right.dataType).map { commonType => | ||
| if (b.inputType.acceptsType(commonType)) { | ||
| // If the expression accepts the tightest common type, cast to that. | ||
| val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) | ||
| val newRight = if (right.dataType == commonType) right else Cast(right, commonType) | ||
| b.withNewChildren(Seq(newLeft, newRight)) | ||
| } else { | ||
| // Otherwise, don't do anything with the expression. | ||
| b | ||
| } | ||
| }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. | ||
| case b @ BinaryOperator(left, right) | ||
| if !BinaryOperator.sameType(left.dataType, right.dataType) => | ||
| (left.dataType, right.dataType) match { | ||
| case (StructType(fields1), StructType(fields2)) => | ||
| val commonTypes = scala.collection.mutable.ArrayBuffer.empty[DataType] | ||
| val len = fields1.length | ||
| var i = 0 | ||
| var continue = fields1.length == fields2.length | ||
| while (i < len && continue) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This loop could be refactored functionally, e.g. |
||
| val commonType = findTightestCommonType(fields1(i).dataType, fields2(i).dataType) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we don't want to support type coercion, we can change this line to use
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about nested structs? |
||
| if (commonType.isDefined) { | ||
| commonTypes += commonType.get | ||
| } else { | ||
| continue = false | ||
| } | ||
| i += 1 | ||
| } | ||
|
|
||
| if (continue) { | ||
| val newLeftST = new StructType(fields1.zip(commonTypes).map { | ||
| case (f, commonType) => f.copy(dataType = commonType) | ||
| }) | ||
| val newLeft = if (left.dataType == newLeftST) left else Cast(left, newLeftST) | ||
|
|
||
| val newRightST = new StructType(fields2.zip(commonTypes).map { | ||
| case (f, commonType) => f.copy(dataType = commonType) | ||
| }) | ||
| val newRight = if (right.dataType == newRightST) right else Cast(right, newRightST) | ||
|
|
||
| if (b.inputType.acceptsType(newLeftST) && b.inputType.acceptsType(newRightST)) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible |
||
| b.withNewChildren(Seq(newLeft, newRight)) | ||
| } else { | ||
| // type not acceptable, don't do anything with the expression. | ||
| b | ||
| } | ||
| } else { | ||
| // left struct type and right struct type have different number of fields, or some | ||
| // fields don't have a common type, don't do anything with the expression. | ||
| b | ||
| } | ||
|
|
||
| case _ => | ||
| findTightestCommonType(left.dataType, right.dataType).map { commonType => | ||
| if (b.inputType.acceptsType(commonType)) { | ||
| // If the expression accepts the tightest common type, cast to that. | ||
| val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This ternary operation seems to crop up a few times in this PR. Maybe we can push it out into a method? |
||
| val newRight = if (right.dataType == commonType) right else Cast(right, commonType) | ||
| b.withNewChildren(Seq(newLeft, newRight)) | ||
| } else { | ||
| // Otherwise, don't do anything with the expression. | ||
| b | ||
| } | ||
| }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. | ||
| } | ||
|
|
||
| case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => | ||
| val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,8 +62,9 @@ trait HashJoin { | |
| } | ||
|
|
||
| protected lazy val (buildKeys, streamedKeys) = { | ||
| require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), | ||
| "Join keys from two sides should have same types") | ||
| require(leftKeys.map(_.dataType).zip(rightKeys.map(_.dataType)).forall { | ||
| case (l, r) => BinaryOperator.sameType(l, r) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have a test case to cover this?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I hit a test failure before changing this. This kind of check (assert, require) can only be hit when there is a bug. |
||
| }, "Join keys from two sides should have same types") | ||
| val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) | ||
| val rkeys = HashJoin.rewriteKeyExpr(rightKeys) | ||
| .map(BindReferences.bindReference(_, right.output)) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,27 @@ | ||
| -- test various operators for comparison, including =, <=>, <, <=, >, >= | ||
|
|
||
| create temporary view data as select * from values | ||
| (1, 1.0D, 'a'), | ||
| (2, 2.0D, 'b'), | ||
| (3, 3.0D, 'c'), | ||
| (null, null, null) | ||
| as data(i, j, k); | ||
|
|
||
| -- binary type | ||
| select x'00' < x'0f'; | ||
| select x'00' < x'ff'; | ||
|
|
||
| -- int type | ||
| select i, i = 2, i = null, i <=> 2, i <=> null, i < 2, i <= 2, i > 2, i >= 2 from data; | ||
|
|
||
| -- double type | ||
| select j, j = 2.0D, j = null, j <=> 2.0D, j <=> null, j < 2.0D, j <= 2.0D, j > 2.0D, j >= 2.0D from data; | ||
|
|
||
| -- string type | ||
| select k, k = 'b', k = null, k <=> 'b', k <=> null, k < 'b', k <= 'b', k > 'b', k >= 'b' from data; | ||
|
|
||
| -- struct type | ||
| select i, j, (i, j) = (2, 2.0D), (i, j) = null, (i, j) < (2, 3.0D) from data; | ||
|
|
||
| -- implicit type cast | ||
| select i, j, i = 2L, (i, j) = (2L, 2.0D) from data; |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,18 +1,86 @@ | ||
| -- Automatically generated by SQLQueryTestSuite | ||
| -- Number of queries: 2 | ||
| -- Number of queries: 8 | ||
|
|
||
|
|
||
| -- !query 0 | ||
| select x'00' < x'0f' | ||
| create temporary view data as select * from values | ||
| (1, 1.0D, 'a'), | ||
| (2, 2.0D, 'b'), | ||
| (3, 3.0D, 'c'), | ||
| (null, null, null) | ||
| as data(i, j, k) | ||
| -- !query 0 schema | ||
| struct<(X'00' < X'0F'):boolean> | ||
| struct<> | ||
| -- !query 0 output | ||
| true | ||
|
|
||
|
|
||
|
|
||
| -- !query 1 | ||
| select x'00' < x'ff' | ||
| select x'00' < x'0f' | ||
| -- !query 1 schema | ||
| struct<(X'00' < X'FF'):boolean> | ||
| struct<(X'00' < X'0F'):boolean> | ||
| -- !query 1 output | ||
| true | ||
|
|
||
|
|
||
| -- !query 2 | ||
| select x'00' < x'ff' | ||
| -- !query 2 schema | ||
| struct<(X'00' < X'FF'):boolean> | ||
| -- !query 2 output | ||
| true | ||
|
|
||
|
|
||
| -- !query 3 | ||
| select i, i = 2, i = null, i <=> 2, i <=> null, i < 2, i <= 2, i > 2, i >= 2 from data | ||
| -- !query 3 schema | ||
| struct<i:int,(i = 2):boolean,(i = CAST(NULL AS INT)):boolean,(i <=> 2):boolean,(i <=> CAST(NULL AS INT)):boolean,(i < 2):boolean,(i <= 2):boolean,(i > 2):boolean,(i >= 2):boolean> | ||
| -- !query 3 output | ||
| 1 false NULL false false true true false false | ||
| 2 true NULL true false false true false true | ||
| 3 false NULL false false false false true true | ||
| NULL NULL NULL false true NULL NULL NULL NULL | ||
|
|
||
|
|
||
| -- !query 4 | ||
| select j, j = 2.0D, j = null, j <=> 2.0D, j <=> null, j < 2.0D, j <= 2.0D, j > 2.0D, j >= 2.0D from data | ||
| -- !query 4 schema | ||
| struct<j:double,(j = 2.0):boolean,(j = CAST(NULL AS DOUBLE)):boolean,(j <=> 2.0):boolean,(j <=> CAST(NULL AS DOUBLE)):boolean,(j < 2.0):boolean,(j <= 2.0):boolean,(j > 2.0):boolean,(j >= 2.0):boolean> | ||
| -- !query 4 output | ||
| 1.0 false NULL false false true true false false | ||
| 2.0 true NULL true false false true false true | ||
| 3.0 false NULL false false false false true true | ||
| NULL NULL NULL false true NULL NULL NULL NULL | ||
|
|
||
|
|
||
| -- !query 5 | ||
| select k, k = 'b', k = null, k <=> 'b', k <=> null, k < 'b', k <= 'b', k > 'b', k >= 'b' from data | ||
| -- !query 5 schema | ||
| struct<k:string,(k = b):boolean,(k = CAST(NULL AS STRING)):boolean,(k <=> b):boolean,(k <=> CAST(NULL AS STRING)):boolean,(k < b):boolean,(k <= b):boolean,(k > b):boolean,(k >= b):boolean> | ||
| -- !query 5 output | ||
| NULL NULL NULL false true NULL NULL NULL NULL | ||
| a false NULL false false true true false false | ||
| b true NULL true false false true false true | ||
| c false NULL false false false false true true | ||
|
|
||
|
|
||
| -- !query 6 | ||
| select i, j, (i, j) = (2, 2.0D), (i, j) = null, (i, j) < (2, 3.0D) from data | ||
| -- !query 6 schema | ||
| struct<i:int,j:double,(named_struct(i, i, j, j) = named_struct(col1, 2, col2, 2.0)):boolean,(named_struct(i, i, j, j) = NULL):boolean,(named_struct(i, i, j, j) < named_struct(col1, 2, col2, 3.0)):boolean> | ||
| -- !query 6 output | ||
| 1 1.0 false NULL true | ||
| 2 2.0 true NULL true | ||
| 3 3.0 false NULL false | ||
| NULL NULL false NULL true | ||
|
|
||
|
|
||
| -- !query 7 | ||
| select i, j, i = 2L, (i, j) = (2L, 2.0D) from data | ||
| -- !query 7 schema | ||
| struct<i:int,j:double,(CAST(i AS BIGINT) = 2):boolean,(named_struct(i, i, j, j) = named_struct(col1, 2, col2, 2.0)):boolean> | ||
| -- !query 7 output | ||
| 1 1.0 false false | ||
| 2 2.0 true true | ||
| 3 3.0 false false | ||
| NULL NULL NULL false |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2675,15 +2675,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { | |
| withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { | ||
| sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))") | ||
| sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") | ||
|
|
||
| withTable("t", "S") { | ||
| sql("CREATE TABLE t(c struct<f:int>) USING parquet") | ||
| sql("CREATE TABLE S(C struct<F:int>) USING parquet") | ||
| Seq(("c", "C"), ("C", "c"), ("c.f", "C.F"), ("C.F", "c.f")).foreach { | ||
| case (left, right) => | ||
| checkAnswer(sql(s"SELECT * FROM t, S WHERE t.$left = S.$right"), Seq.empty) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test case is meant to test set operations, but here it is testing a filter. The new tests should already cover it. |
||
| } | ||
| } | ||
| } | ||
|
|
||
| withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { | ||
|
|
@@ -2696,16 +2687,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { | |
| sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") | ||
| }.message | ||
| assert(m2.contains("Except can only be performed on tables with the compatible column types")) | ||
|
|
||
| withTable("t", "S") { | ||
| sql("CREATE TABLE t(c struct<f:int>) USING parquet") | ||
| sql("CREATE TABLE S(C struct<F:int>) USING parquet") | ||
| checkAnswer(sql("SELECT * FROM t, S WHERE t.c.f = S.C.F"), Seq.empty) | ||
| val m = intercept[AnalysisException] { | ||
| sql("SELECT * FROM t, S WHERE c = C") | ||
| }.message | ||
| assert(m.contains("cannot resolve '(t.`c` = S.`C`)' due to data type mismatch")) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for pinging me, @cloud-fan . Since this removal is a real behavior change instead of new test coverage of |
||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about just changing this line? Can the other changes in this file be done later?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`findTightestCommonType` doesn't accept struct types with different field names, so the other code is needed.