Skip to content

Commit 78d18fd

Browse files
marmbrusrxin
authored andcommitted
[SPARK-2658][SQL] Add rule for true = 1.
Author: Michael Armbrust <[email protected]> Closes apache#1556 from marmbrus/fixBooleanEqualsOne and squashes the following commits: ad8edd4 [Michael Armbrust] Add rule for true = 1 and false = 0.
1 parent 9e7725c commit 78d18fd

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,20 @@ trait HiveTypeCoercion {
231231
* Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
232232
*/
233233
object BooleanComparisons extends Rule[LogicalPlan] {
234+
val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, BigDecimal(1)).map(Literal(_))
235+
val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, BigDecimal(0)).map(Literal(_))
236+
234237
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
235238
// Skip nodes who's children have not been resolved yet.
236239
case e if !e.childrenResolved => e
237-
// No need to change EqualTo operators as that actually makes sense for boolean types.
240+
241+
// Hive treats (true = 1) as true and (false = 0) as true.
242+
case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l
243+
case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r
244+
case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l)
245+
case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r)
246+
247+
// No need to change other EqualTo operators as that actually makes sense for boolean types.
238248
case e: EqualTo => e
239249
// Otherwise turn them to Byte types so that there exists and ordering.
240250
case p: BinaryComparison
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
true true true true true true false false false false false false false false false false false false true true true true true true false false false false false false false false false false false false

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,18 @@ case class TestData(a: Int, b: String)
3030
*/
3131
class HiveQuerySuite extends HiveComparisonTest {
3232

33+
createQueryTest("boolean = number",
34+
"""
35+
|SELECT
36+
| 1 = true, 1L = true, 1Y = true, true = 1, true = 1L, true = 1Y,
37+
| 0 = true, 0L = true, 0Y = true, true = 0, true = 0L, true = 0Y,
38+
| 1 = false, 1L = false, 1Y = false, false = 1, false = 1L, false = 1Y,
39+
| 0 = false, 0L = false, 0Y = false, false = 0, false = 0L, false = 0Y,
40+
| 2 = true, 2L = true, 2Y = true, true = 2, true = 2L, true = 2Y,
41+
| 2 = false, 2L = false, 2Y = false, false = 2, false = 2L, false = 2Y
42+
|FROM src LIMIT 1
43+
""".stripMargin)
44+
3345
test("CREATE TABLE AS runs once") {
3446
hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect()
3547
assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1,

0 commit comments

Comments
 (0)