Skip to content

Commit c4e543d

Browse files
committed
[SPARK-2210] boolean cast on boolean value should be removed.
1 parent 278ec8a commit c4e543d

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,9 @@ trait HiveTypeCoercion {
251251
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
252252
// Skip nodes who's children have not been resolved yet.
253253
case e if !e.childrenResolved => e
254-
254+
// Skip if the type is boolean type already. Note that this extra cast should be removed
255+
// by optimizer.SimplifyCasts.
256+
case Cast(e, BooleanType) if e.dataType == BooleanType => e
255257
case Cast(e, BooleanType) => Not(Equals(e, Literal(0)))
256258
case Cast(e, dataType) if e.dataType == BooleanType =>
257259
Cast(If(e, Literal(1), Literal(0)), dataType)

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@
1717

1818
package org.apache.spark.sql.hive.execution
1919

20+
import org.apache.spark.sql.catalyst.expressions.{Cast, Equals}
21+
import org.apache.spark.sql.execution.Project
22+
import org.apache.spark.sql.hive.test.TestHive
23+
2024
/**
21-
* A set of tests that validate type promotion rules.
25+
* A set of tests that validate type promotion and coercion rules.
2226
*/
2327
class HiveTypeCoercionSuite extends HiveComparisonTest {
2428
val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")
@@ -28,4 +32,23 @@ class HiveTypeCoercionSuite extends HiveComparisonTest {
2832
createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1")
2933
}
3034
}
35+
36+
test("[SPARK-2210] boolean cast on boolean value should be removed") {
37+
val q = "select cast(cast(key=0 as boolean) as boolean) from src"
38+
val project = TestHive.hql(q).queryExecution.executedPlan.collect { case e: Project => e }.head
39+
40+
// No cast expression introduced
41+
project.transformAllExpressions { case c: Cast =>
42+
assert(false, "unexpected cast " + c)
43+
c
44+
}
45+
46+
// Only one Equals
47+
var numEquals = 0
48+
project.transformAllExpressions { case e: Equals =>
49+
numEquals += 1
50+
e
51+
}
52+
assert(numEquals === 1)
53+
}
3154
}

0 commit comments

Comments
 (0)