Skip to content

Commit db81b9d

Browse files
cloud-fanrxin
authored andcommitted
[SPARK-7952][SQL] use internal Decimal instead of java.math.BigDecimal
This PR fixes a bug introduced in #6505. Decimal literal's value is not `java.math.BigDecimal`, but Spark SQL internal type: `Decimal`. Author: Wenchen Fan <[email protected]> Closes #6574 from cloud-fan/fix and squashes the following commits: b0e3549 [Wenchen Fan] rename to BooleanEquality 1987b37 [Wenchen Fan] use Decimal instead of java.math.BigDecimal f93c420 [Wenchen Fan] compare literal
1 parent d6d601a commit db81b9d

File tree

2 files changed

+43
-21
lines changed

2 files changed

+43
-21
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ trait HiveTypeCoercion {
8787
WidenTypes ::
8888
PromoteStrings ::
8989
DecimalPrecision ::
90-
BooleanEqualization ::
90+
BooleanEquality ::
9191
StringToIntegralCasts ::
9292
FunctionArgumentConversion ::
9393
CaseWhenCoercion ::
@@ -479,9 +479,9 @@ trait HiveTypeCoercion {
479479
/**
480480
* Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
481481
*/
482-
object BooleanEqualization extends Rule[LogicalPlan] {
483-
private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
484-
private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))
482+
object BooleanEquality extends Rule[LogicalPlan] {
483+
private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1))
484+
private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal(0))
485485

486486
private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
487487
CaseKeyWhen(numericExpr, Seq(
@@ -512,22 +512,22 @@ trait HiveTypeCoercion {
512512
// all other cases are considered as false.
513513

514514
// We may simplify the expression if one side is literal numeric values
515-
case EqualTo(left @ BooleanType(), Literal(value, _: NumericType))
516-
if trueValues.contains(value) => left
517-
case EqualTo(left @ BooleanType(), Literal(value, _: NumericType))
518-
if falseValues.contains(value) => Not(left)
519-
case EqualTo(Literal(value, _: NumericType), right @ BooleanType())
520-
if trueValues.contains(value) => right
521-
case EqualTo(Literal(value, _: NumericType), right @ BooleanType())
522-
if falseValues.contains(value) => Not(right)
523-
case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType))
524-
if trueValues.contains(value) => And(IsNotNull(left), left)
525-
case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType))
526-
if falseValues.contains(value) => And(IsNotNull(left), Not(left))
527-
case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType())
528-
if trueValues.contains(value) => And(IsNotNull(right), right)
529-
case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType())
530-
if falseValues.contains(value) => And(IsNotNull(right), Not(right))
515+
case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
516+
if trueValues.contains(value) => bool
517+
case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
518+
if falseValues.contains(value) => Not(bool)
519+
case EqualTo(Literal(value, _: NumericType), bool @ BooleanType())
520+
if trueValues.contains(value) => bool
521+
case EqualTo(Literal(value, _: NumericType), bool @ BooleanType())
522+
if falseValues.contains(value) => Not(bool)
523+
case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType))
524+
if trueValues.contains(value) => And(IsNotNull(bool), bool)
525+
case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType))
526+
if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
527+
case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType())
528+
if trueValues.contains(value) => And(IsNotNull(bool), bool)
529+
case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType())
530+
if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
531531

532532
case EqualTo(left @ BooleanType(), right @ NumericType()) =>
533533
transform(left , right)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ class HiveTypeCoercionSuite extends PlanTest {
147147
}
148148

149149
test("type coercion simplification for equal to") {
150-
val be = new HiveTypeCoercion {}.BooleanEqualization
150+
val be = new HiveTypeCoercion {}.BooleanEquality
151+
151152
ruleTest(be,
152153
EqualTo(Literal(true), Literal(1)),
153154
Literal(true)
@@ -164,5 +165,26 @@ class HiveTypeCoercionSuite extends PlanTest {
164165
EqualNullSafe(Literal(true), Literal(0)),
165166
And(IsNotNull(Literal(true)), Not(Literal(true)))
166167
)
168+
169+
ruleTest(be,
170+
EqualTo(Literal(true), Literal(1L)),
171+
Literal(true)
172+
)
173+
ruleTest(be,
174+
EqualTo(Literal(new java.math.BigDecimal(1)), Literal(true)),
175+
Literal(true)
176+
)
177+
ruleTest(be,
178+
EqualTo(Literal(BigDecimal(0)), Literal(true)),
179+
Not(Literal(true))
180+
)
181+
ruleTest(be,
182+
EqualTo(Literal(Decimal(1)), Literal(true)),
183+
Literal(true)
184+
)
185+
ruleTest(be,
186+
EqualTo(Literal.create(Decimal(1), DecimalType(8, 0)), Literal(true)),
187+
Literal(true)
188+
)
167189
}
168190
}

0 commit comments

Comments
 (0)