Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -803,18 +803,60 @@ object TypeCoercion {
e.copy(left = Cast(e.left, TimestampType))
}

case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
findTightestCommonType(left.dataType, right.dataType).map { commonType =>
if (b.inputType.acceptsType(commonType)) {
// If the expression accepts the tightest common type, cast to that.
val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
b.withNewChildren(Seq(newLeft, newRight))
} else {
// Otherwise, don't do anything with the expression.
b
}
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
case b @ BinaryOperator(left, right)
if !BinaryOperator.sameType(left.dataType, right.dataType) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about just change this line? The other changes in this file can be done later?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

findTightestCommonType doesn't accept struct type with different filed names, so the other code are needed.

(left.dataType, right.dataType) match {
case (StructType(fields1), StructType(fields2)) =>
val commonTypes = scala.collection.mutable.ArrayBuffer.empty[DataType]
val len = fields1.length
var i = 0
var continue = fields1.length == fields2.length
while (i < len && continue) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loop could be refactored functionally, e.g.

val commonTypes = (fields1 zip fields2).map(f => findTightestCommonType(f._1, f._2))
if (commonTypes.forall(_.isDefined)) {
 . . .

val commonType = findTightestCommonType(fields1(i).dataType, fields2(i).dataType)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we don't want to support type coercion, we can change this line to use fields1(i).dataType == fields2(i).dataType

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about nested structs?

if (commonType.isDefined) {
commonTypes += commonType.get
} else {
continue = false
}
i += 1
}

if (continue) {
val newLeftST = new StructType(fields1.zip(commonTypes).map {
case (f, commonType) => f.copy(dataType = commonType)
})
val newLeft = if (left.dataType == newLeftST) left else Cast(left, newLeftST)

val newRightST = new StructType(fields2.zip(commonTypes).map {
case (f, commonType) => f.copy(dataType = commonType)
})
val newRight = if (right.dataType == newRightST) right else Cast(right, newRightST)

if (b.inputType.acceptsType(newLeftST) && b.inputType.acceptsType(newRightST)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible b only accepts one side (e.g., only newLeftST) but doesn't accept other side?

b.withNewChildren(Seq(newLeft, newRight))
} else {
// type not acceptable, don't do anything with the expression.
b
}
} else {
// left struct type and right struct type have different number of fields, or some
// fields don't have a common type, don't do anything with the expression.
b
}

case _ =>
findTightestCommonType(left.dataType, right.dataType).map { commonType =>
if (b.inputType.acceptsType(commonType)) {
// If the expression accepts the tightest common type, cast to that.
val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ternary operation seems to crop up a few times in this PR. Maybe we can push it out into a method?

private def castIfNeeded(e: Expression, possibleType: DataType): Expression = {
  if (e.dataType == possibleType) data else Cast(e, possibleType)
}

val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
b.withNewChildren(Seq(newLeft, newRight))
} else {
// Otherwise, don't do anything with the expression.
b
}
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
}

case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty =>
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {

override def checkInputDataTypes(): TypeCheckResult = {
// First check whether left and right have the same type, then check if the type is acceptable.
if (!left.dataType.sameType(right.dataType)) {
if (!BinaryOperator.sameType(left.dataType, right.dataType)) {
TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
} else if (!inputType.acceptsType(left.dataType)) {
Expand All @@ -595,6 +595,10 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {

object BinaryOperator {
def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right))

def sameType(left: DataType, right: DataType): Boolean = {
DataType.equalsStructurally(left, right, ignoreNullability = true)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ trait HashJoin {
}

protected lazy val (buildKeys, streamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
require(leftKeys.map(_.dataType).zip(rightKeys.map(_.dataType)).forall {
case (l, r) => BinaryOperator.sameType(l, r)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a test case to cover this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hit a test failure before changing this. This kind of check(assert, require) can only be hitten when there is a bug.

}, "Join keys from two sides should have same types")
val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
.map(BindReferences.bindReference(_, right.output))
Expand Down
24 changes: 24 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/comparator.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,27 @@
-- test various operator for comparison, including =, <=>, <, <=, >, >=

create temporary view data as select * from values
(1, 1.0D, 'a'),
(2, 2.0D, 'b'),
(3, 3.0D, 'c'),
(null, null, null)
as data(i, j, k);

-- binary type
select x'00' < x'0f';
select x'00' < x'ff';

-- int type
select i, i = 2, i = null, i <=> 2, i <=> null, i < 2, i <= 2, i > 2, i >= 2 from data;

-- decimal type
select j, j = 2.0D, j = null, j <=> 2.0D, j <=> null, j < 2.0D, j <= 2.0D, j > 2.0D, j >= 2.0D from data;

-- string type
select k, k = 'b', k = null, k <=> 'b', k <=> null, k < 'b', k <= 'b', k > 'b', k >= 'b' from data;

-- struct type
select i, j, (i, j) = (2, 2.0D), (i, j) = null, (i, j) < (2, 3.0D) from data;

-- implicit type cast
select i, j, i = 2L, (i, j) = (2L, 2.0D) from data;
80 changes: 74 additions & 6 deletions sql/core/src/test/resources/sql-tests/results/comparator.sql.out
Original file line number Diff line number Diff line change
@@ -1,18 +1,86 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 2
-- Number of queries: 8


-- !query 0
select x'00' < x'0f'
create temporary view data as select * from values
(1, 1.0D, 'a'),
(2, 2.0D, 'b'),
(3, 3.0D, 'c'),
(null, null, null)
as data(i, j, k)
-- !query 0 schema
struct<(X'00' < X'0F'):boolean>
struct<>
-- !query 0 output
true



-- !query 1
select x'00' < x'ff'
select x'00' < x'0f'
-- !query 1 schema
struct<(X'00' < X'FF'):boolean>
struct<(X'00' < X'0F'):boolean>
-- !query 1 output
true


-- !query 2
select x'00' < x'ff'
-- !query 2 schema
struct<(X'00' < X'FF'):boolean>
-- !query 2 output
true


-- !query 3
select i, i = 2, i = null, i <=> 2, i <=> null, i < 2, i <= 2, i > 2, i >= 2 from data
-- !query 3 schema
struct<i:int,(i = 2):boolean,(i = CAST(NULL AS INT)):boolean,(i <=> 2):boolean,(i <=> CAST(NULL AS INT)):boolean,(i < 2):boolean,(i <= 2):boolean,(i > 2):boolean,(i >= 2):boolean>
-- !query 3 output
1 false NULL false false true true false false
2 true NULL true false false true false true
3 false NULL false false false false true true
NULL NULL NULL false true NULL NULL NULL NULL


-- !query 4
select j, j = 2.0D, j = null, j <=> 2.0D, j <=> null, j < 2.0D, j <= 2.0D, j > 2.0D, j >= 2.0D from data
-- !query 4 schema
struct<j:double,(j = 2.0):boolean,(j = CAST(NULL AS DOUBLE)):boolean,(j <=> 2.0):boolean,(j <=> CAST(NULL AS DOUBLE)):boolean,(j < 2.0):boolean,(j <= 2.0):boolean,(j > 2.0):boolean,(j >= 2.0):boolean>
-- !query 4 output
1.0 false NULL false false true true false false
2.0 true NULL true false false true false true
3.0 false NULL false false false false true true
NULL NULL NULL false true NULL NULL NULL NULL


-- !query 5
select k, k = 'b', k = null, k <=> 'b', k <=> null, k < 'b', k <= 'b', k > 'b', k >= 'b' from data
-- !query 5 schema
struct<k:string,(k = b):boolean,(k = CAST(NULL AS STRING)):boolean,(k <=> b):boolean,(k <=> CAST(NULL AS STRING)):boolean,(k < b):boolean,(k <= b):boolean,(k > b):boolean,(k >= b):boolean>
-- !query 5 output
NULL NULL NULL false true NULL NULL NULL NULL
a false NULL false false true true false false
b true NULL true false false true false true
c false NULL false false false false true true


-- !query 6
select i, j, (i, j) = (2, 2.0D), (i, j) = null, (i, j) < (2, 3.0D) from data
-- !query 6 schema
struct<i:int,j:double,(named_struct(i, i, j, j) = named_struct(col1, 2, col2, 2.0)):boolean,(named_struct(i, i, j, j) = NULL):boolean,(named_struct(i, i, j, j) < named_struct(col1, 2, col2, 3.0)):boolean>
-- !query 6 output
1 1.0 false NULL true
2 2.0 true NULL true
3 3.0 false NULL false
NULL NULL false NULL true


-- !query 7
select i, j, i = 2L, (i, j) = (2L, 2.0D) from data
-- !query 7 schema
struct<i:int,j:double,(CAST(i AS BIGINT) = 2):boolean,(named_struct(i, i, j, j) = named_struct(col1, 2, col2, 2.0)):boolean>
-- !query 7 output
1 1.0 false false
2 2.0 true true
3 3.0 false false
NULL NULL NULL false
19 changes: 0 additions & 19 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2675,15 +2675,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))")
sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))")

withTable("t", "S") {
sql("CREATE TABLE t(c struct<f:int>) USING parquet")
sql("CREATE TABLE S(C struct<F:int>) USING parquet")
Seq(("c", "C"), ("C", "c"), ("c.f", "C.F"), ("C.F", "c.f")).foreach {
case (left, right) =>
checkAnswer(sql(s"SELECT * FROM t, S WHERE t.$left = S.$right"), Seq.empty)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test case wants to test set operation, but here it's testing filter. the new tests should've covered it.

}
}
}

withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
Expand All @@ -2696,16 +2687,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))")
}.message
assert(m2.contains("Except can only be performed on tables with the compatible column types"))

withTable("t", "S") {
sql("CREATE TABLE t(c struct<f:int>) USING parquet")
sql("CREATE TABLE S(C struct<F:int>) USING parquet")
checkAnswer(sql("SELECT * FROM t, S WHERE t.c.f = S.C.F"), Seq.empty)
val m = intercept[AnalysisException] {
sql("SELECT * FROM t, S WHERE c = C")
}.message
assert(m.contains("cannot resolve '(t.`c` = S.`C`)' due to data type mismatch"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pinging me, @cloud-fan .

Since this removal is a real behavior change instead of new test coverage of comparator.sql, could you add a documentation for this?

}
}
}

Expand Down